Skip to content

Commit e522266

Browse files
committed
Move trainer to py_paddle.trainer
1 parent 704ed1e commit e522266

File tree

4 files changed

+10
-17
lines changed

4 files changed

+10
-17
lines changed

demo/mnist/api_train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from paddle.trainer_config_helpers import *
1212

1313
import mnist_provider
14-
from trainer import *
14+
from py_paddle.trainer import *
1515

1616

1717
@network(
@@ -40,7 +40,7 @@ def main():
4040
file_list=['./data/raw_data/train']).with_std_tester(
4141
method=mnist_provider.process,
4242
file_list=['./data/raw_data/t10k']).build()
43-
with runner.use():
43+
with runner:
4444
for _ in xrange(2):
4545
runner.run_one_pass()
4646

paddle/py_paddle/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
'paddle',
2020
'DataProviderConverter',
2121
'DataProviderWrapperConverter', # for deprecated usage.
22-
'loadParameterFile'
22+
'loadParameterFile',
23+
'trainer'
2324
]
2425
util.monkeypatches()

demo/mnist/trainer.py renamed to paddle/py_paddle/trainer/__init__.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,6 @@ class RunnerContext(object):
120120
pass
121121

122122

123-
class RunnerSection(object):
124-
def __init__(self, runner):
125-
self.runner = runner
126-
127-
def __enter__(self):
128-
self.runner.initialize()
129-
130-
def __exit__(self, exc_type, exc_val, exc_tb):
131-
self.runner.finalize(self.runner.context)
132-
133-
134123
class Runner(object):
135124
def __init__(self):
136125
self.chains = []
@@ -197,8 +186,11 @@ def run_one_pass(self, parent=None):
197186
exit_flag = self.end_batch(self.context)
198187
self.end_pass(self.context)
199188

200-
def use(self):
201-
return RunnerSection(self)
189+
def __enter__(self):
190+
self.initialize()
191+
192+
def __exit__(self, exc_type, exc_val, exc_tb):
193+
self.finalize(self.context)
202194

203195

204196
class DeviceChainItem(RunnerChainItem):

paddle/setup.py.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ setup(name="py_paddle",
6464
extra_compile_args = extra_comps
6565
)
6666
],
67-
packages=['py_paddle'],
67+
packages=['py_paddle', 'py_paddle.trainer'],
6868
include_dirs = include_dirs,
6969
install_requires = [
7070
'numpy>=1.8.0', # The numpy is required.

0 commit comments

Comments
 (0)