Skip to content

Commit 2fdadf5

Browse files
committed
Add introduction ipynb
1 parent 3ceee61 commit 2fdadf5

File tree

7 files changed

+262 additions, -31 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,3 +12,4 @@ Makefile
1212

1313
*~
1414
bazel-*
15+
.ipynb_checkpoints

demo/introduction/linear.ipynb

Lines changed: 210 additions & 0 deletions
Large diffs are not rendered by default.

paddle/py_paddle/trainer/base.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -224,6 +224,7 @@ def run_one_pass(self):
224224
break
225225
exit_flag = self.__end_batch__()
226226
self.__end_pass__()
227+
return self.__context__
227228

228229
def __enter__(self):
229230
"""

paddle/py_paddle/trainer/builder.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,17 +16,22 @@ def __init__(self, network, use_gpu=False, device_count=1):
1616
self.__train_data__ = None
1717
self.__updater__ = None
1818
self.__gradient_machine__ = None
19+
self.__init_param__ = None
1920
self.__evaluate__ = []
2021

2122
def with_std_random_init_params(self):
22-
self.__runner__.add_item(std_random_init_params())
23+
self.__init_param__ = std_random_init_params()
2324
return self
2425

25-
def with_train_data(self, method, file_list, batch_size=None, **kwargs):
26+
def with_train_data(self, method, file_list=None, batch_size=None,
27+
**kwargs):
2628
if batch_size is None:
2729
batch_size = self.__network__.optimize_graph().batch_size
2830

29-
self.__train_data__ = BasicTrainerDataProvider(
31+
if file_list is None:
32+
file_list = [None]
33+
34+
self.__train_data__ = BasicPaddleTrainerDataProvider(
3035
network=self.__network__,
3136
method=method,
3237
file_list=file_list,
@@ -77,8 +82,15 @@ def with_observer(self, on_batch_end=None, on_pass_end=None):
7782
return self
7883

7984
def build(self):
85+
if self.__init_param__ is None:
86+
self.with_std_random_init_params()
87+
self.__runner__.add_item(self.__init_param__)
8088
self.__runner__.add_item(self.__train_data__)
89+
if self.__updater__ is None:
90+
self.with_std_local_updater()
8191
self.__runner__.add_item(self.__updater__)
92+
if self.__gradient_machine__ is None:
93+
self.with_std_gradient_machine_ops()
8294
self.__runner__.add_item(self.__gradient_machine__)
8395
for each in self.__evaluate__:
8496
self.__runner__.add_item(each)
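(File path not captured in this extract; the diff below belongs to a different file than builder.py above.)

Lines changed: 25 additions & 21 deletions
Original file line number | Diff line number | Diff line change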
@@ -1,51 +1,41 @@
11
from .. import DataProviderConverter
22
import random
33

4-
__all__ = ['DataProvider', 'NaiveDataProvider']
4+
__all__ = ['DataProvider', 'NaiveMemPooledDataProvider', 'NaiveDataProvider']
55

66

77
class DataProvider(object):
88
__slots__ = [
9-
'__init__', 'reset', 'next', '__provider__', '__converter__',
9+
'__init__', 'reset', 'next', '__method__', '__converter__',
1010
'__batch_size__', '__should_shuffle__'
1111
]
1212

13-
def __init__(self, provider, input_types, batch_size, should_shuffle=True):
14-
self.__provider__ = provider
13+
def __init__(self, method, input_types, batch_size, should_shuffle=True):
14+
self.__method__ = method
1515
self.__converter__ = DataProviderConverter(input_types)
1616
self.__batch_size__ = batch_size
17-
if self.__provider__.should_shuffle is None:
18-
self.__provider__.should_shuffle = should_shuffle
17+
self.__should_shuffle__ = should_shuffle
1918

2019
def reset(self):
2120
raise NotImplemented()
2221

2322
def next(self):
2423
raise NotImplemented()
2524

26-
def __should_shuffle__(self):
27-
return self.__provider__.should_shuffle
28-
2925

30-
class NaiveDataProvider(DataProvider):
31-
def __init__(self, provider, input_types, batch_size, should_shuffle=True):
32-
super(NaiveDataProvider, self).__init__(
33-
provider=provider,
26+
class NaiveMemPooledDataProvider(DataProvider):
27+
def __init__(self, method, input_types, batch_size, should_shuffle):
28+
super(NaiveMemPooledDataProvider, self).__init__(
29+
method=method,
3430
input_types=input_types,
3531
batch_size=batch_size,
3632
should_shuffle=should_shuffle)
3733
self.__pool__ = []
3834
self.__idx__ = 0
3935

4036
def reset(self):
41-
def __to_pool__():
42-
for filename in self.__provider__.file_list:
43-
for item in self.__provider__.generator(self.__provider__,
44-
filename):
45-
yield item
46-
47-
self.__pool__ = list(__to_pool__())
48-
if self.__should_shuffle__():
37+
self.__pool__ = list(self.__method__())
38+
if self.__should_shuffle__:
4939
random.shuffle(self.__pool__)
5040

5141
self.__idx__ = 0
@@ -58,3 +48,17 @@ def next(self):
5848
return self.__converter__(self.__pool__[begin:end]), end - begin
5949
else:
6050
raise StopIteration
51+
52+
53+
class NaiveDataProvider(NaiveMemPooledDataProvider):
54+
def __init__(self, provider, input_types, batch_size, should_shuffle=True):
55+
def __to_pool__():
56+
for filename in provider.file_list:
57+
for item in provider.generator(provider, filename):
58+
yield item
59+
60+
super(NaiveDataProvider, self).__init__(
61+
method=__to_pool__,
62+
input_types=input_types,
63+
batch_size=batch_size,
64+
should_shuffle=should_shuffle)

paddle/py_paddle/trainer/items.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,8 @@
99
'BasicLocalParameterUpdaterOps', 'BasicLocalParameterUpdater', 'Counter',
1010
'BatchEvaluate', 'PassEvaluate', 'BasicGradientMachineTrainOps',
1111
'BasicGradientMachineTestOps', 'InheritGradientMachineUpdater',
12-
'TestOnPassEnd', 'BasicTrainerDataProvider', 'BasicDataProviderOps',
13-
'BasicTestDataProvider', 'SaveParamsOnPassEnd', 'BaseObserveItem'
12+
'TestOnPassEnd', 'BasicPaddleTrainerDataProvider', 'BasicDataProviderOps',
13+
'BasicPaddleTestDataProvider', 'SaveParamsOnPassEnd', 'BaseObserveItem'
1414
]
1515

1616

@@ -330,8 +330,7 @@ def on_batch_begin(self, next_callback):
330330
def data_provider_creator(is_train):
331331
class __cls__(BasicDataProviderOps):
332332
def __init__(self, network, method, file_list, batch_size, **kwargs):
333-
BasicDataProviderOps.__init__(self)
334-
assert isinstance(network, NetworkConfig)
333+
super(__cls__, self).__init__()
335334
self.__dataprovider__ = method(
336335
file_list=file_list,
337336
input_order=network.input_order(),
@@ -355,16 +354,16 @@ def initialize(self, context, next_callback):
355354
return __cls__
356355

357356

358-
BasicTrainerDataProvider = data_provider_creator(True)
359-
BasicTestDataProvider = data_provider_creator(False)
357+
BasicPaddleTrainerDataProvider = data_provider_creator(True)
358+
BasicPaddleTestDataProvider = data_provider_creator(False)
360359

361360

362361
class TestOnPassEnd(RunnerItem):
363362
def __init__(self, **kwargs):
364363
RunnerItem.__init__(self)
365364
self.__test_runner__ = Runner()
366365
self.__test_runner__.add_item(InheritGradientMachineUpdater())
367-
self.__test_runner__.add_item(BasicTestDataProvider(**kwargs))
366+
self.__test_runner__.add_item(BasicPaddleTestDataProvider(**kwargs))
368367
self.__test_runner__.add_item(BasicGradientMachineTestOps())
369368
self.__test_runner__.add_item(PassEvaluate(prefix='Test: '))
370369

paddle/py_paddle/trainer/network.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from paddle.trainer_config_helpers import *
22
from paddle.trainer_config_helpers import inputs as ipts
3+
import paddle.trainer.PyDataProvider2 as dp2
34

45
__all__ = ['NetworkConfig', 'network']
56

@@ -52,6 +53,9 @@ def optimize_graph(self):
5253
"""
5354
raise NotImplemented()
5455

56+
def provider(self, **kwargs):
57+
return dp2.provider(input_types=self.input_types(), **kwargs)
58+
5559

5660
def network(inputs, **opt_kwargs):
5761
"""

0 commit comments

Comments
 (0)