Skip to content

Commit 82b8a3c

Browse files
committed
Move trainer to contrib
1 parent 3fbfcd9 commit 82b8a3c

File tree

5 files changed

+1374
-1351
lines changed

5 files changed

+1374
-1351
lines changed

python/paddle/fluid/__init__.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,8 @@
1919
# import all class inside executor into fluid module
2020
from . import executor
2121
from .executor import *
22-
2322
from . import trainer
24-
from .trainer import Trainer
25-
from .trainer import BeginEpochEvent
26-
from .trainer import EndEpochEvent
27-
from .trainer import BeginStepEvent
28-
from .trainer import EndStepEvent
29-
from .trainer import CheckpointConfig
30-
3123
from . import inferencer
32-
from .inferencer import Inferencer
3324

3425
from . import io
3526
from . import evaluator
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import contextlib
18+
19+
from .. import core
20+
21+
from .. import executor
22+
from .. import framework
23+
from .. import io
24+
from .. import parallel_executor
25+
from .. import unique_name
26+
from .trainer import check_and_get_place
27+
28+
__all__ = ['Inferencer', ]
29+
30+
31+
class Inferencer(object):
32+
"""
33+
Inferencer High Level API.
34+
35+
Args:
36+
infer_func (Python func): Infer function that will return predict Variable
37+
param_path (str): The path where the inference model is saved by fluid.io.save_params
38+
place (Place): place to do the inference
39+
parallel (bool): use parallel_executor to run the inference, it will use multi CPU/GPU.
40+
41+
Examples:
42+
.. code-block:: python
43+
44+
def inference_program():
45+
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
46+
y_predict = fluid.layers.fc(input=x, size=1, act=None)
47+
return y_predict
48+
49+
place = fluid.CPUPlace()
50+
inferencer = fluid.Inferencer(
51+
infer_func=inference_program, param_path="/tmp/model", place=place)
52+
53+
"""
54+
55+
def __init__(self, infer_func, param_path, place=None, parallel=False):
56+
self.param_path = param_path
57+
self.scope = core.Scope()
58+
self.parallel = parallel
59+
self.place = check_and_get_place(place)
60+
61+
self.inference_program = framework.Program()
62+
with framework.program_guard(self.inference_program):
63+
with unique_name.guard():
64+
self.predict_var = infer_func()
65+
66+
with self._prog_and_scope_guard():
67+
# load params from param_path into scope
68+
io.load_params(executor.Executor(self.place), param_path)
69+
70+
if parallel:
71+
with self._prog_and_scope_guard():
72+
self.exe = parallel_executor.ParallelExecutor(
73+
use_cuda=isinstance(self.place, core.CUDAPlace),
74+
loss_name=self.predict_var.name)
75+
else:
76+
self.exe = executor.Executor(self.place)
77+
78+
self.inference_program = self.inference_program.clone(for_test=True)
79+
80+
def infer(self, inputs, return_numpy=True):
81+
"""
82+
Do Inference for Inputs
83+
84+
Args:
85+
inputs (map): a map of {"input_name": input_var} that will be feed into the inference program
86+
return_numpy (bool): transform return value into numpy or not
87+
88+
Returns:
89+
Tensor or Numpy: the predict value of the inference model for the inputs
90+
91+
Examples:
92+
.. code-block:: python
93+
94+
tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")
95+
results = inferencer.infer({'x': tensor_x})
96+
"""
97+
if not isinstance(inputs, dict):
98+
raise ValueError(
99+
"inputs should be a map of {'input_name': input_var}")
100+
101+
with self._prog_and_scope_guard():
102+
results = self.exe.run(feed=inputs,
103+
fetch_list=[self.predict_var.name],
104+
return_numpy=return_numpy)
105+
106+
return results
107+
108+
@contextlib.contextmanager
109+
def _prog_and_scope_guard(self):
110+
with framework.program_guard(main_program=self.inference_program):
111+
with executor.scope_guard(self.scope):
112+
yield

0 commit comments

Comments
 (0)