|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import numpy as np
|
16 |
| -import os |
17 |
| -import shutil |
| 16 | +from framework import Program |
| 17 | +from executor import global_scope |
18 | 18 | from . import core
|
19 | 19 |
|
20 | 20 |
|
21 | 21 | class InferenceTranspiler:
|
22 |
| - def transpile(self, program, scope, place): |
| 22 | + def transpile(self, program, place, scope=None): |
23 | 23 | '''
|
24 | 24 | Transpile the program. Support only fuse batch normalization now.
|
25 | 25 |
|
26 | 26 | :param program: program to transpile
|
27 | 27 | :type program: Program
|
28 |
| - :param scope: inference scope |
29 |
| - :type scope: Scope |
30 | 28 | :param place: inference place
|
31 | 29 | :type place: Place
|
| 30 | + :param scope: inference scope |
| 31 | + :type scope: Scope or None |
32 | 32 | '''
|
33 |
| - self.fuse_batch_norm(program, scope, place) |
34 |
| - |
35 |
| - def fuse_batch_norm(self, program, scope, place): |
| 33 | + if not isinstance(program, Program): |
| 34 | + raise TypeError("program should be as Program type") |
| 35 | + if not isinstance(place, core.CPUPlace) and not isinstance( |
| 36 | + place, core.CUDAPlace): |
| 37 | + raise TypeError("place should be as CPUPlace/CUDAPlace type") |
| 38 | + if scope is None: |
| 39 | + scope = global_scope() |
| 40 | + if not isinstance(scope, core.Scope): |
| 41 | + raise TypeError("scope should be as Scope type or None") |
| 42 | + self.fuse_batch_norm(program, place, scope) |
| 43 | + |
| 44 | + def fuse_batch_norm(self, program, place, scope): |
36 | 45 | '''
|
37 | 46 | Transpile the program by fused batch normalization.
|
38 | 47 |
|
@@ -66,10 +75,10 @@ def fuse_batch_norm(self, program, scope, place):
|
66 | 75 |
|
67 | 76 | :param program: program to transpile
|
68 | 77 | :type program: Program
|
69 |
| - :param scope: inference scope |
70 |
| - :type scope: Scope |
71 | 78 | :param place: inference place
|
72 | 79 | :type place: Place
|
| 80 | + :param scope: inference scope |
| 81 | + :type scope: Scope |
73 | 82 | '''
|
74 | 83 | self.scope = scope
|
75 | 84 | self.place = place
|
|
0 commit comments