@@ -42,25 +42,35 @@ def train(self, reader, topology, parameters, event_handler=None):
42
42
43
43
44
44
class SGD (ITrainer ):
45
- def __init__ (self , update_equation ):
45
+ def __init__ (self , cost , parameters , update_equation ):
46
46
"""
47
47
Simple SGD Trainer.
48
48
49
49
:param update_equation: The optimizer object.
50
50
:type update_equation: v2_optimizer.Optimizer
51
51
"""
52
+
53
+ if not isinstance (parameters , v2_parameters .Parameters ):
54
+ raise TypeError ('parameters should be parameters' )
55
+
52
56
if not isinstance (update_equation , v2_optimizer .Optimizer ):
53
- raise ValueError ("update equation parameter must be "
54
- "paddle.v2.optimizer.Optimizer" )
57
+ raise TypeError ("update equation parameter must be "
58
+ "paddle.v2.optimizer.Optimizer" )
59
+ topology = Topology (cost )
55
60
self .__optimizer__ = update_equation
61
+ self .__topology__ = topology
62
+ self .__parameters__ = parameters
63
+ self .__topology_in_proto__ = topology .proto ()
64
+ self .__data_types__ = topology .data_type ()
65
+ gm = api .GradientMachine .createFromConfigProto (
66
+ self .__topology_in_proto__ , api .CREATE_MODE_NORMAL ,
67
+ self .__optimizer__ .enable_types ())
68
+ assert isinstance (gm , api .GradientMachine )
69
+ parameters .append_gradient_machine (gm )
70
+ self .__gradient_machine__ = gm
71
+ self .__gradient_machine__ .randParameters ()
56
72
57
- def train (self ,
58
- reader ,
59
- cost ,
60
- parameters ,
61
- num_passes = 1 ,
62
- event_handler = None ,
63
- reader_dict = None ):
73
+ def train (self , reader , num_passes = 1 , event_handler = None , reader_dict = None ):
64
74
"""
65
75
Training method. Will train num_passes of input data.
66
76
@@ -76,44 +86,41 @@ def train(self,
76
86
if event_handler is None :
77
87
event_handler = default_event_handler
78
88
79
- topology = Topology (cost )
89
+ if reader_dict is None :
90
+ reader_dict = self .default_reader_dict ()
80
91
81
92
__check_train_args__ (** locals ())
82
93
83
- gm = api .GradientMachine .createFromConfigProto (
84
- topology .proto (), api .CREATE_MODE_NORMAL ,
85
- self .__optimizer__ .enable_types ())
86
- assert isinstance (gm , api .GradientMachine )
87
- parameters .append_gradient_machine (gm )
88
- gm .randParameters ()
89
94
updater = self .__optimizer__ .create_local_updater ()
90
- updater .init (gm )
95
+ updater .init (self . __gradient_machine__ )
91
96
92
- gm .start ()
93
- batch_evaluator = gm .makeEvaluator ()
97
+ self . __gradient_machine__ .start ()
98
+ batch_evaluator = self . __gradient_machine__ .makeEvaluator ()
94
99
assert isinstance (batch_evaluator , api .Evaluator )
95
- pass_evaluator = gm .makeEvaluator ()
100
+ pass_evaluator = self . __gradient_machine__ .makeEvaluator ()
96
101
assert isinstance (pass_evaluator , api .Evaluator )
97
102
out_args = api .Arguments .createArguments (0 )
98
103
99
- feeder = DataFeeder (topology . data_type () , reader_dict )
104
+ feeder = DataFeeder (self . __data_types__ , reader_dict )
100
105
101
106
for pass_id in xrange (num_passes ):
102
107
event_handler (v2_event .BeginPass (pass_id ))
103
108
pass_evaluator .start ()
104
109
updater .startPass ()
105
110
for batch_id , data_batch in enumerate (reader ()):
106
111
pass_type = updater .startBatch (len (data_batch ))
107
- gm .forwardBackward (feeder (data_batch ), out_args , pass_type )
112
+ self .__gradient_machine__ .forwardBackward (
113
+ feeder (data_batch ), out_args , pass_type )
108
114
batch_evaluator .start ()
109
115
event_handler (
110
116
v2_event .BeginIteration (
111
117
pass_id = pass_id , batch_id = batch_id ))
112
118
pass_type = updater .startBatch (len (data_batch ))
113
- gm .forwardBackward (feeder (data_batch ), out_args , pass_type )
114
- gm .eval (pass_evaluator )
115
- gm .eval (batch_evaluator )
116
- for each_param in gm .getParameters ():
119
+ self .__gradient_machine__ .forwardBackward (
120
+ feeder (data_batch ), out_args , pass_type )
121
+ self .__gradient_machine__ .eval (pass_evaluator )
122
+ self .__gradient_machine__ .eval (batch_evaluator )
123
+ for each_param in self .__gradient_machine__ .getParameters ():
117
124
updater .update (each_param )
118
125
# Get cost. We use numpy to calculate total cost for this batch.
119
126
cost_vec = out_args .getSlotValue (0 )
@@ -131,22 +138,37 @@ def train(self,
131
138
updater .finishPass ()
132
139
pass_evaluator .finish ()
133
140
event_handler (v2_event .EndPass (pass_id , evaluator = pass_evaluator ))
134
- gm .finish ()
141
+ self .__gradient_machine__ .finish ()
142
+
143
+ def default_reader_dict (self ):
144
+ reader_dict = dict ()
145
+ for i , tp in enumerate (self .__data_types__ ):
146
+ reader_dict [tp [0 ]] = i
147
+ return reader_dict
148
+
149
+ def test (self , reader , reader_dict = None ):
150
+ if reader_dict is None :
151
+ reader_dict = self .default_reader_dict ()
152
+
153
+ feeder = DataFeeder (self .__data_types__ , reader_dict )
154
+ evaluator = self .__gradient_machine__ .makeEvaluator ()
155
+ out_args = api .Arguments .createArguments (0 )
156
+ evaluator .start ()
157
+ for data_batch in reader ():
158
+ self .__gradient_machine__ .forward (
159
+ feeder (data_batch ), out_args , api .PASS_TEST )
160
+ self .__gradient_machine__ .eval (evaluator )
161
+
162
+ evaluator .finish ()
163
+ return v2_event .TestResult (evaluator = evaluator )
135
164
136
165
137
- def __check_train_args__ (reader , topology , parameters , event_handler , ** kwargs ):
166
+ def __check_train_args__ (reader , event_handler , ** kwargs ):
138
167
"""
139
168
Check train function's argument types
140
169
"""
141
170
if not callable (reader ) or not isinstance (reader (), collections .Iterator ):
142
171
raise TypeError ('train_data_reader should be a function, '
143
172
'which can return a iterator' )
144
-
145
- if not isinstance (topology , Topology ):
146
- raise TypeError ('topology should be a model config' )
147
-
148
- if not isinstance (parameters , v2_parameters .Parameters ):
149
- raise TypeError ('parameters should be a parameter pool' )
150
-
151
173
if not callable (event_handler ):
152
174
raise TypeError ('event handler should be a function' )
0 commit comments