19
19
import data_feeder
20
20
import contextlib
21
21
import io
22
- import transpiler
22
+ import unique_name
23
23
24
24
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
25
25
import optimizer as opt_module
@@ -56,26 +56,62 @@ def __init__(self, epoch_id, step_id):
56
56
self .step = step_id
57
57
58
58
59
+ def check_and_get_place (place ):
60
+ """
61
+ Check the type of place or get the default place
62
+ Args:
63
+ place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
64
+
65
+ Raises:
66
+ TypeError if the type mismatched.
67
+
68
+ Returns:
69
+ the original place if it is not None.
70
+ if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
71
+ Otherwise returns CPUPlace by default.
72
+ """
73
+ if place is None :
74
+ if core .is_compiled_with_cuda ():
75
+ return core .CUDAPlace (0 )
76
+ else :
77
+ return core .CPUPlace ()
78
+ else :
79
+ if not isinstance (place , core .CUDAPlace ) and not isinstance (
80
+ place , core .CPUPlace ):
81
+ raise TypeError ("Place should be either CUDAPlace or CPUPlace" )
82
+ return place
83
+
84
+
59
85
class Trainer (object ):
60
86
"""
61
87
62
88
Args:
63
- program_func(callable): A function which will return loss. The loss must be a scaler.
89
+ train_func(callable): A function which will return loss. The loss must be a scalar.
90
+ infer_func(callable): A function which will return predict, used to save inference model
64
91
optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer
65
92
place: The device place of this trainer.
66
93
"""
67
94
68
- def __init__ (self , program_func , optimizer , param_path = None , place = None ):
95
+ def __init__ (self ,
96
+ train_func ,
97
+ infer_func ,
98
+ optimizer ,
99
+ param_path = None ,
100
+ place = None ):
69
101
# 1. we need to generate a framework.Program by calling
70
102
# program_func. Reference: fluid.program_guard in
71
103
# test_word2vec.py
104
+ if not isinstance (optimizer , opt_module .Optimizer ):
105
+ raise TypeError ("The optimizer should be an instance of Optimizer" )
106
+
107
+ self .infer_func = infer_func
72
108
self .scope = core .Scope ()
73
109
74
110
self .startup_program = framework .Program ()
75
111
self .train_program = framework .Program ()
76
112
77
113
with framework .program_guard (self .train_program , self .startup_program ):
78
- program_func_outs = program_func ()
114
+ program_func_outs = train_func ()
79
115
self .test_outputs = program_func_outs if isinstance (
80
116
program_func_outs , list ) else [program_func_outs ]
81
117
self .test_program = self .train_program .clone ()
@@ -86,9 +122,9 @@ def __init__(self, program_func, optimizer, param_path=None, place=None):
86
122
loss = self .test_outputs [0 ]
87
123
optimize_ops , params_grads = optimizer .minimize (loss )
88
124
89
- self .place = Trainer . _check_and_get_place (place )
125
+ self .place = check_and_get_place (place )
90
126
91
- self .dist_transpile_if_necessary (optimize_ops , params_grads )
127
+ self ._dist_transpile_if_necessary (optimize_ops , params_grads )
92
128
93
129
# 2. move the default_main_program to self.program and run the
94
130
# default_startup program on an empty core.Scope()
@@ -101,7 +137,7 @@ def __init__(self, program_func, optimizer, param_path=None, place=None):
101
137
# load params from param_path into scope
102
138
io .load_persistables (exe , dirname = param_path )
103
139
104
- def dist_transpile_if_necessary (self , optimize_ops , params_grads ):
140
+ def _dist_transpile_if_necessary (self , optimize_ops , params_grads ):
105
141
if "PADDLE_TRAINING_ROLE" not in os .environ :
106
142
return
107
143
@@ -190,31 +226,14 @@ def save_params(self, param_path):
190
226
exe = executor .Executor (self .place )
191
227
io .save_persistables (exe , dirname = param_path )
192
228
193
- @staticmethod
194
- def _check_and_get_place (place ):
195
- """
196
- Check the type of place or get the default place
197
- Args:
198
- place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
199
-
200
- Raises:
201
- TypeError if the type mismatched.
202
-
203
- Returns:
204
- the original place if it is not None.
205
- if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
206
- Otherwise returns CPUPlace by default.
207
- """
208
- if place is None :
209
- if core .is_compiled_with_cuda ():
210
- return core .CUDAPlace (0 )
211
- else :
212
- return core .CPUPlace ()
213
- else :
214
- if not isinstance (place , core .CUDAPlace ) and not isinstance (
215
- place , core .CPUPlace ):
216
- raise TypeError ("Place should be either CUDAPlace or CPUPlace" )
217
- return place
229
+ def save_inference_model (self , model_path ):
230
+ inference_program = framework .Program ()
231
+ with framework .program_guard (inference_program ):
232
+ with unique_name .guard ():
233
+ predict_var = self .infer_func ()
234
+ predict_var = self .train_program .block (0 ).var (predict_var .name )
235
+ exe = executor .Executor (self .place )
236
+ io .save_inference_model (model_path , [], [predict_var ], exe )
218
237
219
238
@contextlib .contextmanager
220
239
def _prog_and_scope_guard (self ):
0 commit comments