Skip to content

Commit 81c47b2

Browse files
committed
add type check and default scope
1 parent 01b88f2 commit 81c47b2

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

python/paddle/fluid/inference_transpiler.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,35 @@
1313
# limitations under the License.
1414

1515
import numpy as np
16-
import os
17-
import shutil
16+
from framework import Program
17+
from executor import global_scope
1818
from . import core
1919

2020

2121
class InferenceTranspiler:
22-
def transpile(self, program, scope, place):
22+
def transpile(self, program, place, scope=None):
2323
'''
2424
Transpile the program. Support only fuse batch normalization now.
2525
2626
:param program: program to transpile
2727
:type program: Program
28-
:param scope: inference scope
29-
:type scope: Scope
3028
:param place: inference place
3129
:type place: Place
30+
:param scope: inference scope
31+
:type scope: Scope or None
3232
'''
33-
self.fuse_batch_norm(program, scope, place)
34-
35-
def fuse_batch_norm(self, program, scope, place):
33+
if not isinstance(program, Program):
34+
raise TypeError("program should be as Program type")
35+
if not isinstance(place, core.CPUPlace) and not isinstance(
36+
place, core.CUDAPlace):
37+
raise TypeError("place should be as CPUPlace/CUDAPlace type")
38+
if scope is None:
39+
scope = global_scope()
40+
if not isinstance(scope, core.Scope):
41+
raise TypeError("scope should be as Scope type or None")
42+
self.fuse_batch_norm(program, place, scope)
43+
44+
def fuse_batch_norm(self, program, place, scope):
3645
'''
3746
Transpile the program by fused batch normalization.
3847
@@ -66,10 +75,10 @@ def fuse_batch_norm(self, program, scope, place):
6675
6776
:param program: program to transpile
6877
:type program: Program
69-
:param scope: inference scope
70-
:type scope: Scope
7178
:param place: inference place
7279
:type place: Place
80+
:param scope: inference scope
81+
:type scope: Scope
7382
'''
7483
self.scope = scope
7584
self.place = place

python/paddle/fluid/tests/book/test_image_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def infer(use_cuda, save_dirname=None):
229229
# Use inference_transpiler to speedup
230230
inference_transpiler_program = inference_program.clone()
231231
t = fluid.InferenceTranspiler()
232-
t.transpile(inference_transpiler_program, inference_scope, place)
232+
t.transpile(inference_transpiler_program, place)
233233

234234
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
235235
# and results will contain a list of data corresponding to fetch_targets.

0 commit comments

Comments
 (0)