Skip to content

Commit 1f4820d

Browse files
committed
Extract base.py
1 parent e522266 commit 1f4820d

File tree

2 files changed

+300
-146
lines changed

2 files changed

+300
-146
lines changed

paddle/py_paddle/trainer/__init__.py

Lines changed: 45 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
import random
12
import functools
2-
from py_paddle import swig_paddle as api
3-
from py_paddle import DataProviderConverter
3+
44
from paddle.trainer_config_helpers import *
55
from paddle.trainer_config_helpers import inputs as ipts
6-
import random
6+
7+
from .base import *
8+
from .. import DataProviderConverter
9+
from .. import swig_paddle as api
710

811
__all__ = [
9-
'RunnerChainItem', 'Runner', 'DeviceChainItem', 'CreateGradientMachine',
12+
'RunnerItem', 'Runner', 'DeviceItem', 'CreateGradientMachine',
1013
'RandomInitializeParams', 'BasicLocalParameterUpdater', 'network',
1114
'BasicTrainerDataProvider', 'BasicDataProviderOps',
1215
'BasicGradientMachineTrainOps', 'Counter', 'BatchEvaluate',
@@ -89,113 +92,9 @@ def __optimize_graph_func__():
8992
return __impl__
9093

9194

92-
class RunnerChainItem(object):
93-
def __init__(self):
94-
pass
95-
96-
def initialize(self, context, next_callback):
97-
next_callback(context)
98-
99-
def finalize(self, context, next_callback):
100-
next_callback(context)
101-
102-
def on_pass_begin(self, context, next_callback):
103-
next_callback(context)
104-
105-
def on_pass_end(self, context, next_callback):
106-
next_callback(context)
107-
108-
def on_batch_begin(self, context, next_callback):
109-
return next_callback(context)
110-
111-
def on_batch_end(self, context, next_callback):
112-
return next_callback(context)
113-
114-
115-
def default_next_callback(*args, **kwargs):
116-
return False
117-
118-
119-
class RunnerContext(object):
120-
pass
121-
122-
123-
class Runner(object):
124-
def __init__(self):
125-
self.chains = []
126-
127-
self.begin_pass = None
128-
self.end_pass = None
129-
self.begin_batch = None
130-
self.end_batch = None
131-
self.finalize = None
132-
133-
self.context = RunnerContext()
134-
self.context.runner = self
135-
136-
def add_chain_item(self, item):
137-
assert isinstance(item, RunnerChainItem)
138-
self.chains.append(item)
139-
140-
def initialize(self, parent=None):
141-
if None not in [
142-
self.begin_pass, self.end_pass, self.begin_batch,
143-
self.end_batch, self.finalize
144-
]:
145-
return False
146-
else:
147-
assert len(self.chains) != 0
148-
actual_init = default_next_callback
149-
self.begin_pass = default_next_callback
150-
self.end_pass = default_next_callback
151-
self.begin_batch = default_next_callback
152-
self.end_batch = default_next_callback
153-
self.finalize = default_next_callback
154-
155-
for chain in reversed(self.chains):
156-
assert isinstance(chain, RunnerChainItem)
157-
actual_init = functools.partial(
158-
chain.initialize, next_callback=actual_init)
159-
self.begin_pass = functools.partial(
160-
chain.on_pass_begin, next_callback=self.begin_pass)
161-
self.end_pass = functools.partial(
162-
chain.on_pass_end, next_callback=self.end_pass)
163-
self.begin_batch = functools.partial(
164-
chain.on_batch_begin, next_callback=self.begin_batch)
165-
self.end_batch = functools.partial(
166-
chain.on_batch_end, next_callback=self.end_batch)
167-
self.finalize = functools.partial(
168-
chain.finalize, next_callback=self.finalize)
169-
170-
if parent is not None:
171-
self.context.parent = parent
172-
173-
actual_init(self.context)
174-
return True
175-
176-
def run_one_pass(self, parent=None):
177-
if parent is not None:
178-
self.context.parent = parent
179-
180-
self.begin_pass(self.context)
181-
exit_flag = False
182-
while not exit_flag:
183-
exit_flag = self.begin_batch(self.context)
184-
if exit_flag:
185-
break
186-
exit_flag = self.end_batch(self.context)
187-
self.end_pass(self.context)
188-
189-
def __enter__(self):
190-
self.initialize()
191-
192-
def __exit__(self, exc_type, exc_val, exc_tb):
193-
self.finalize(self.context)
194-
195-
196-
class DeviceChainItem(RunnerChainItem):
95+
class DeviceItem(RunnerItem):
19796
def __init__(self, use_gpu=False, device_count=4):
198-
RunnerChainItem.__init__(self)
97+
RunnerItem.__init__(self)
19998
self.use_gpu = use_gpu
20099
self.device_count = device_count
201100

@@ -205,9 +104,9 @@ def initialize(self, context, next_callback):
205104
next_callback(context)
206105

207106

208-
class CreateGradientMachine(RunnerChainItem):
107+
class CreateGradientMachine(RunnerItem):
209108
def __init__(self, network):
210-
RunnerChainItem.__init__(self)
109+
RunnerItem.__init__(self)
211110
assert isinstance(network, NetworkConfig)
212111
self.__network__ = network
213112

@@ -237,9 +136,9 @@ def finalize(self, context, next_callback):
237136
next_callback(context)
238137

239138

240-
class RandomInitializeParams(RunnerChainItem):
139+
class RandomInitializeParams(RunnerItem):
241140
def __init__(self):
242-
RunnerChainItem.__init__(self)
141+
RunnerItem.__init__(self)
243142

244143
def initialize(self, context, next_callback):
245144
assert hasattr(context, 'gradient_machine') and isinstance(
@@ -248,12 +147,12 @@ def initialize(self, context, next_callback):
248147
next_callback(context)
249148

250149

251-
class BasicLocalParameterUpdaterOps(RunnerChainItem):
150+
class BasicLocalParameterUpdaterOps(RunnerItem):
252151
def __init__(self,
253152
updater_name='updater',
254153
batch_size_name='current_batch_size',
255154
cost_name='current_cost'):
256-
RunnerChainItem.__init__(self)
155+
RunnerItem.__init__(self)
257156
self.__updater_name__ = updater_name
258157
self.__batch_size_name__ = batch_size_name
259158
self.__cost_name__ = cost_name
@@ -311,9 +210,9 @@ def initialize(self, context, next_callback):
311210
next_callback(context)
312211

313212

314-
class BasicGradientMachineTrainOps(RunnerChainItem):
213+
class BasicGradientMachineTrainOps(RunnerItem):
315214
def __init__(self):
316-
RunnerChainItem.__init__(self)
215+
RunnerItem.__init__(self)
317216
self.__out_args__ = api.Arguments.createArguments(0)
318217

319218
def on_batch_begin(self, context, next_callback):
@@ -334,9 +233,9 @@ def on_batch_begin(self, context, next_callback):
334233
return next_callback(context)
335234

336235

337-
class Counter(RunnerChainItem):
236+
class Counter(RunnerItem):
338237
def __init__(self):
339-
RunnerChainItem.__init__(self)
238+
RunnerItem.__init__(self)
340239

341240
def initialize(self, context, next_callback):
342241
context.current_pass_id = 0
@@ -353,9 +252,9 @@ def on_pass_end(self, context, next_callback):
353252
context.current_pass_id += 1
354253

355254

356-
class BaseEvaluate(RunnerChainItem):
255+
class BaseEvaluate(RunnerItem):
357256
def __init__(self, prefix=None):
358-
RunnerChainItem.__init__(self)
257+
RunnerItem.__init__(self)
359258
self.__evaluator__ = None
360259
if prefix is None:
361260
prefix = ''
@@ -409,9 +308,9 @@ def on_pass_end(self, context, next_callback):
409308
self.__evaluator__.finish()
410309

411310

412-
class BasicGradientMachineTestOps(RunnerChainItem):
311+
class BasicGradientMachineTestOps(RunnerItem):
413312
def __init__(self):
414-
RunnerChainItem.__init__(self)
313+
RunnerItem.__init__(self)
415314
self.__out_args__ = api.Arguments.createArguments(0)
416315

417316
def on_pass_begin(self, context, next_callback):
@@ -428,9 +327,9 @@ def on_pass_end(self, context, next_callback):
428327
next_callback(context)
429328

430329

431-
class InheritGradientMachineUpdater(RunnerChainItem):
330+
class InheritGradientMachineUpdater(RunnerItem):
432331
def __init__(self):
433-
RunnerChainItem.__init__(self)
332+
RunnerItem.__init__(self)
434333

435334
def initialize(self, context, next_callback):
436335
if context.parent is not None:
@@ -449,18 +348,18 @@ def on_batch_begin(self, context, next_callback):
449348
return next_callback(context)
450349

451350

452-
class TestOnPassEnd(RunnerChainItem):
351+
class TestOnPassEnd(RunnerItem):
453352
def __init__(self, **kwargs):
454-
RunnerChainItem.__init__(self)
353+
RunnerItem.__init__(self)
455354
self.__test_runner__ = Runner()
456-
self.__test_runner__.add_chain_item(InheritGradientMachineUpdater())
457-
self.__test_runner__.add_chain_item(BasicTestDataProvider(**kwargs))
458-
self.__test_runner__.add_chain_item(BasicGradientMachineTestOps())
459-
self.__test_runner__.add_chain_item(PassEvaluate(prefix='Test: '))
355+
self.__test_runner__.add_item(InheritGradientMachineUpdater())
356+
self.__test_runner__.add_item(BasicTestDataProvider(**kwargs))
357+
self.__test_runner__.add_item(BasicGradientMachineTestOps())
358+
self.__test_runner__.add_item(PassEvaluate(prefix='Test: '))
460359

461360
def initialize(self, context, next_callback):
462361
next_callback(context)
463-
self.__test_runner__.initialize(context)
362+
self.__test_runner__.__initialize__(context)
464363

465364
def on_pass_end(self, context, next_callback):
466365
self.__test_runner__.run_one_pass(parent=context)
@@ -515,9 +414,9 @@ def next(self):
515414
raise StopIteration
516415

517416

518-
class BasicDataProviderOps(RunnerChainItem):
417+
class BasicDataProviderOps(RunnerItem):
519418
def __init__(self, provider_name='data_provider'):
520-
RunnerChainItem.__init__(self)
419+
RunnerItem.__init__(self)
521420
self.__provider_name__ = provider_name
522421

523422
def __get_provider__(self, context):
@@ -575,9 +474,9 @@ def initialize(self, context, next_callback):
575474
BasicTestDataProvider = data_provider_creator(False)
576475

577476

578-
class SaveParamsOnPassEnd(RunnerChainItem):
477+
class SaveParamsOnPassEnd(RunnerItem):
579478
def __init__(self):
580-
RunnerChainItem.__init__(self)
479+
RunnerItem.__init__(self)
581480

582481
def on_pass_end(self, context, next_callback):
583482
context.updater.catchUpWith()
@@ -591,12 +490,12 @@ def on_pass_end(self, context, next_callback):
591490
class RunnerBuilder(object):
592491
def __init__(self, network, use_gpu=False, device_count=1):
593492
self.__runner__ = Runner()
594-
self.__runner__.add_chain_item(Counter())
493+
self.__runner__.add_item(Counter())
595494
self.__network__ = network
596-
self.__runner__.add_chain_item(
597-
DeviceChainItem(
495+
self.__runner__.add_item(
496+
DeviceItem(
598497
use_gpu=use_gpu, device_count=device_count))
599-
self.__runner__.add_chain_item(
498+
self.__runner__.add_item(
600499
CreateGradientMachine(network=self.__network__))
601500

602501
self.__train_data__ = None
@@ -605,7 +504,7 @@ def __init__(self, network, use_gpu=False, device_count=1):
605504
self.__evaluate__ = []
606505

607506
def with_std_random_init_params(self):
608-
self.__runner__.add_chain_item(RandomInitializeParams())
507+
self.__runner__.add_item(RandomInitializeParams())
609508
return self
610509

611510
def with_train_data(self, method, file_list, batch_size=None, **kwargs):
@@ -659,9 +558,9 @@ def with_std_local_trainer(self, **kwargs):
659558
).with_batch_evaluator().with_std_param_saver()
660559

661560
def build(self):
662-
self.__runner__.add_chain_item(self.__train_data__)
663-
self.__runner__.add_chain_item(self.__updater__)
664-
self.__runner__.add_chain_item(self.__gradient_machine__)
561+
self.__runner__.add_item(self.__train_data__)
562+
self.__runner__.add_item(self.__updater__)
563+
self.__runner__.add_item(self.__gradient_machine__)
665564
for each in self.__evaluate__:
666-
self.__runner__.add_chain_item(each)
565+
self.__runner__.add_item(each)
667566
return self.__runner__

0 commit comments

Comments
 (0)