Skip to content

Commit e8d24aa

Browse files
committed
Inferencer support parallel_executor
1 parent 2a63652 commit e8d24aa

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

python/paddle/fluid/inferencer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,36 @@
1717
import executor
1818
import framework
1919
import io
20+
import parallel_executor
2021
import unique_name
2122
from trainer import check_and_get_place
2223

2324
__all__ = ['Inferencer', ]
2425

2526

2627
class Inferencer(object):
27-
def __init__(self, infer_func, param_path, place=None):
28+
def __init__(self, infer_func, param_path, place=None, parallel=False):
2829
"""
2930
:param infer_func: a function that will return predict Variable
3031
:param param_path: the path where the inference model is saved by fluid.io.save_params
3132
:param place: place to do the inference
3233
"""
3334
self.param_path = param_path
3435
self.scope = core.Scope()
36+
self.parallel = parallel
37+
self.place = check_and_get_place(place)
3538

3639
self.inference_program = framework.Program()
3740
with framework.program_guard(self.inference_program):
3841
with unique_name.guard():
3942
self.predict_var = infer_func()
4043

41-
self.exe = executor.Executor(check_and_get_place(place))
44+
if parallel:
45+
self.exe = parallel_executor.ParallelExecutor(
46+
use_cuda=isinstance(self.place, core.CUDAPlace),
47+
loss_name=self.predict_var.name)
48+
else:
49+
self.exe = executor.Executor(self.place)
4250
with executor.scope_guard(self.scope):
4351
# load params from param_path into scope
4452
io.load_params(self.exe, param_path, self.inference_program)

python/paddle/fluid/trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import contextlib
1516
import os
17+
1618
import core
17-
import framework
18-
import executor
19+
1920
import data_feeder
20-
import contextlib
21+
import executor
22+
import framework
2123
import io
22-
import unique_name
23-
import parallel_executor
24-
2524
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
2625
import optimizer as opt_module
26+
import parallel_executor
2727
from transpiler import distribute_transpiler
2828

2929
__all__ = [

0 commit comments

Comments
 (0)