1
+ import random
1
2
import functools
2
- from py_paddle import swig_paddle as api
3
- from py_paddle import DataProviderConverter
3
+
4
4
from paddle .trainer_config_helpers import *
5
5
from paddle .trainer_config_helpers import inputs as ipts
6
- import random
6
+
7
+ from .base import *
8
+ from .. import DataProviderConverter
9
+ from .. import swig_paddle as api
7
10
8
11
__all__ = [
9
- 'RunnerChainItem ' , 'Runner' , 'DeviceChainItem ' , 'CreateGradientMachine' ,
12
+ 'RunnerItem ' , 'Runner' , 'DeviceItem ' , 'CreateGradientMachine' ,
10
13
'RandomInitializeParams' , 'BasicLocalParameterUpdater' , 'network' ,
11
14
'BasicTrainerDataProvider' , 'BasicDataProviderOps' ,
12
15
'BasicGradientMachineTrainOps' , 'Counter' , 'BatchEvaluate' ,
@@ -89,113 +92,9 @@ def __optimize_graph_func__():
89
92
return __impl__
90
93
91
94
92
- class RunnerChainItem (object ):
93
- def __init__ (self ):
94
- pass
95
-
96
- def initialize (self , context , next_callback ):
97
- next_callback (context )
98
-
99
- def finalize (self , context , next_callback ):
100
- next_callback (context )
101
-
102
- def on_pass_begin (self , context , next_callback ):
103
- next_callback (context )
104
-
105
- def on_pass_end (self , context , next_callback ):
106
- next_callback (context )
107
-
108
- def on_batch_begin (self , context , next_callback ):
109
- return next_callback (context )
110
-
111
- def on_batch_end (self , context , next_callback ):
112
- return next_callback (context )
113
-
114
-
115
- def default_next_callback (* args , ** kwargs ):
116
- return False
117
-
118
-
119
- class RunnerContext (object ):
120
- pass
121
-
122
-
123
- class Runner (object ):
124
- def __init__ (self ):
125
- self .chains = []
126
-
127
- self .begin_pass = None
128
- self .end_pass = None
129
- self .begin_batch = None
130
- self .end_batch = None
131
- self .finalize = None
132
-
133
- self .context = RunnerContext ()
134
- self .context .runner = self
135
-
136
- def add_chain_item (self , item ):
137
- assert isinstance (item , RunnerChainItem )
138
- self .chains .append (item )
139
-
140
- def initialize (self , parent = None ):
141
- if None not in [
142
- self .begin_pass , self .end_pass , self .begin_batch ,
143
- self .end_batch , self .finalize
144
- ]:
145
- return False
146
- else :
147
- assert len (self .chains ) != 0
148
- actual_init = default_next_callback
149
- self .begin_pass = default_next_callback
150
- self .end_pass = default_next_callback
151
- self .begin_batch = default_next_callback
152
- self .end_batch = default_next_callback
153
- self .finalize = default_next_callback
154
-
155
- for chain in reversed (self .chains ):
156
- assert isinstance (chain , RunnerChainItem )
157
- actual_init = functools .partial (
158
- chain .initialize , next_callback = actual_init )
159
- self .begin_pass = functools .partial (
160
- chain .on_pass_begin , next_callback = self .begin_pass )
161
- self .end_pass = functools .partial (
162
- chain .on_pass_end , next_callback = self .end_pass )
163
- self .begin_batch = functools .partial (
164
- chain .on_batch_begin , next_callback = self .begin_batch )
165
- self .end_batch = functools .partial (
166
- chain .on_batch_end , next_callback = self .end_batch )
167
- self .finalize = functools .partial (
168
- chain .finalize , next_callback = self .finalize )
169
-
170
- if parent is not None :
171
- self .context .parent = parent
172
-
173
- actual_init (self .context )
174
- return True
175
-
176
- def run_one_pass (self , parent = None ):
177
- if parent is not None :
178
- self .context .parent = parent
179
-
180
- self .begin_pass (self .context )
181
- exit_flag = False
182
- while not exit_flag :
183
- exit_flag = self .begin_batch (self .context )
184
- if exit_flag :
185
- break
186
- exit_flag = self .end_batch (self .context )
187
- self .end_pass (self .context )
188
-
189
- def __enter__ (self ):
190
- self .initialize ()
191
-
192
- def __exit__ (self , exc_type , exc_val , exc_tb ):
193
- self .finalize (self .context )
194
-
195
-
196
- class DeviceChainItem (RunnerChainItem ):
95
+ class DeviceItem (RunnerItem ):
197
96
def __init__ (self , use_gpu = False , device_count = 4 ):
198
- RunnerChainItem .__init__ (self )
97
+ RunnerItem .__init__ (self )
199
98
self .use_gpu = use_gpu
200
99
self .device_count = device_count
201
100
@@ -205,9 +104,9 @@ def initialize(self, context, next_callback):
205
104
next_callback (context )
206
105
207
106
208
- class CreateGradientMachine (RunnerChainItem ):
107
+ class CreateGradientMachine (RunnerItem ):
209
108
def __init__ (self , network ):
210
- RunnerChainItem .__init__ (self )
109
+ RunnerItem .__init__ (self )
211
110
assert isinstance (network , NetworkConfig )
212
111
self .__network__ = network
213
112
@@ -237,9 +136,9 @@ def finalize(self, context, next_callback):
237
136
next_callback (context )
238
137
239
138
240
- class RandomInitializeParams (RunnerChainItem ):
139
+ class RandomInitializeParams (RunnerItem ):
241
140
def __init__ (self ):
242
- RunnerChainItem .__init__ (self )
141
+ RunnerItem .__init__ (self )
243
142
244
143
def initialize (self , context , next_callback ):
245
144
assert hasattr (context , 'gradient_machine' ) and isinstance (
@@ -248,12 +147,12 @@ def initialize(self, context, next_callback):
248
147
next_callback (context )
249
148
250
149
251
- class BasicLocalParameterUpdaterOps (RunnerChainItem ):
150
+ class BasicLocalParameterUpdaterOps (RunnerItem ):
252
151
def __init__ (self ,
253
152
updater_name = 'updater' ,
254
153
batch_size_name = 'current_batch_size' ,
255
154
cost_name = 'current_cost' ):
256
- RunnerChainItem .__init__ (self )
155
+ RunnerItem .__init__ (self )
257
156
self .__updater_name__ = updater_name
258
157
self .__batch_size_name__ = batch_size_name
259
158
self .__cost_name__ = cost_name
@@ -311,9 +210,9 @@ def initialize(self, context, next_callback):
311
210
next_callback (context )
312
211
313
212
314
- class BasicGradientMachineTrainOps (RunnerChainItem ):
213
+ class BasicGradientMachineTrainOps (RunnerItem ):
315
214
def __init__ (self ):
316
- RunnerChainItem .__init__ (self )
215
+ RunnerItem .__init__ (self )
317
216
self .__out_args__ = api .Arguments .createArguments (0 )
318
217
319
218
def on_batch_begin (self , context , next_callback ):
@@ -334,9 +233,9 @@ def on_batch_begin(self, context, next_callback):
334
233
return next_callback (context )
335
234
336
235
337
- class Counter (RunnerChainItem ):
236
+ class Counter (RunnerItem ):
338
237
def __init__ (self ):
339
- RunnerChainItem .__init__ (self )
238
+ RunnerItem .__init__ (self )
340
239
341
240
def initialize (self , context , next_callback ):
342
241
context .current_pass_id = 0
@@ -353,9 +252,9 @@ def on_pass_end(self, context, next_callback):
353
252
context .current_pass_id += 1
354
253
355
254
356
- class BaseEvaluate (RunnerChainItem ):
255
+ class BaseEvaluate (RunnerItem ):
357
256
def __init__ (self , prefix = None ):
358
- RunnerChainItem .__init__ (self )
257
+ RunnerItem .__init__ (self )
359
258
self .__evaluator__ = None
360
259
if prefix is None :
361
260
prefix = ''
@@ -409,9 +308,9 @@ def on_pass_end(self, context, next_callback):
409
308
self .__evaluator__ .finish ()
410
309
411
310
412
- class BasicGradientMachineTestOps (RunnerChainItem ):
311
+ class BasicGradientMachineTestOps (RunnerItem ):
413
312
def __init__ (self ):
414
- RunnerChainItem .__init__ (self )
313
+ RunnerItem .__init__ (self )
415
314
self .__out_args__ = api .Arguments .createArguments (0 )
416
315
417
316
def on_pass_begin (self , context , next_callback ):
@@ -428,9 +327,9 @@ def on_pass_end(self, context, next_callback):
428
327
next_callback (context )
429
328
430
329
431
- class InheritGradientMachineUpdater (RunnerChainItem ):
330
+ class InheritGradientMachineUpdater (RunnerItem ):
432
331
def __init__ (self ):
433
- RunnerChainItem .__init__ (self )
332
+ RunnerItem .__init__ (self )
434
333
435
334
def initialize (self , context , next_callback ):
436
335
if context .parent is not None :
@@ -449,18 +348,18 @@ def on_batch_begin(self, context, next_callback):
449
348
return next_callback (context )
450
349
451
350
452
- class TestOnPassEnd (RunnerChainItem ):
351
+ class TestOnPassEnd (RunnerItem ):
453
352
def __init__ (self , ** kwargs ):
454
- RunnerChainItem .__init__ (self )
353
+ RunnerItem .__init__ (self )
455
354
self .__test_runner__ = Runner ()
456
- self .__test_runner__ .add_chain_item (InheritGradientMachineUpdater ())
457
- self .__test_runner__ .add_chain_item (BasicTestDataProvider (** kwargs ))
458
- self .__test_runner__ .add_chain_item (BasicGradientMachineTestOps ())
459
- self .__test_runner__ .add_chain_item (PassEvaluate (prefix = 'Test: ' ))
355
+ self .__test_runner__ .add_item (InheritGradientMachineUpdater ())
356
+ self .__test_runner__ .add_item (BasicTestDataProvider (** kwargs ))
357
+ self .__test_runner__ .add_item (BasicGradientMachineTestOps ())
358
+ self .__test_runner__ .add_item (PassEvaluate (prefix = 'Test: ' ))
460
359
461
360
def initialize (self , context , next_callback ):
462
361
next_callback (context )
463
- self .__test_runner__ .initialize (context )
362
+ self .__test_runner__ .__initialize__ (context )
464
363
465
364
def on_pass_end (self , context , next_callback ):
466
365
self .__test_runner__ .run_one_pass (parent = context )
@@ -515,9 +414,9 @@ def next(self):
515
414
raise StopIteration
516
415
517
416
518
- class BasicDataProviderOps (RunnerChainItem ):
417
+ class BasicDataProviderOps (RunnerItem ):
519
418
def __init__ (self , provider_name = 'data_provider' ):
520
- RunnerChainItem .__init__ (self )
419
+ RunnerItem .__init__ (self )
521
420
self .__provider_name__ = provider_name
522
421
523
422
def __get_provider__ (self , context ):
@@ -575,9 +474,9 @@ def initialize(self, context, next_callback):
575
474
BasicTestDataProvider = data_provider_creator (False )
576
475
577
476
578
- class SaveParamsOnPassEnd (RunnerChainItem ):
477
+ class SaveParamsOnPassEnd (RunnerItem ):
579
478
def __init__ (self ):
580
- RunnerChainItem .__init__ (self )
479
+ RunnerItem .__init__ (self )
581
480
582
481
def on_pass_end (self , context , next_callback ):
583
482
context .updater .catchUpWith ()
@@ -591,12 +490,12 @@ def on_pass_end(self, context, next_callback):
591
490
class RunnerBuilder (object ):
592
491
def __init__ (self , network , use_gpu = False , device_count = 1 ):
593
492
self .__runner__ = Runner ()
594
- self .__runner__ .add_chain_item (Counter ())
493
+ self .__runner__ .add_item (Counter ())
595
494
self .__network__ = network
596
- self .__runner__ .add_chain_item (
597
- DeviceChainItem (
495
+ self .__runner__ .add_item (
496
+ DeviceItem (
598
497
use_gpu = use_gpu , device_count = device_count ))
599
- self .__runner__ .add_chain_item (
498
+ self .__runner__ .add_item (
600
499
CreateGradientMachine (network = self .__network__ ))
601
500
602
501
self .__train_data__ = None
@@ -605,7 +504,7 @@ def __init__(self, network, use_gpu=False, device_count=1):
605
504
self .__evaluate__ = []
606
505
607
506
def with_std_random_init_params (self ):
608
- self .__runner__ .add_chain_item (RandomInitializeParams ())
507
+ self .__runner__ .add_item (RandomInitializeParams ())
609
508
return self
610
509
611
510
def with_train_data (self , method , file_list , batch_size = None , ** kwargs ):
@@ -659,9 +558,9 @@ def with_std_local_trainer(self, **kwargs):
659
558
).with_batch_evaluator ().with_std_param_saver ()
660
559
661
560
def build (self ):
662
- self .__runner__ .add_chain_item (self .__train_data__ )
663
- self .__runner__ .add_chain_item (self .__updater__ )
664
- self .__runner__ .add_chain_item (self .__gradient_machine__ )
561
+ self .__runner__ .add_item (self .__train_data__ )
562
+ self .__runner__ .add_item (self .__updater__ )
563
+ self .__runner__ .add_item (self .__gradient_machine__ )
665
564
for each in self .__evaluate__ :
666
- self .__runner__ .add_chain_item (each )
565
+ self .__runner__ .add_item (each )
667
566
return self .__runner__
0 commit comments