@@ -61,7 +61,7 @@ def __init__(self, cost, parameters, update_equation):
61
61
self .__gradient_machine__ .randParameters ()
62
62
parameters .append_gradient_machine (gm )
63
63
64
- def train (self , reader , num_passes = 1 , event_handler = None , reader_dict = None ):
64
+ def train (self , reader , num_passes = 1 , event_handler = None , feeding = None ):
65
65
"""
66
66
Training method. Will train num_passes of input data.
67
67
@@ -70,14 +70,13 @@ def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
70
70
:param event_handler: Event handler. A method will be invoked when event
71
71
occurred.
72
72
:type event_handler: (BaseEvent) => None
73
+ :param feeding: Feeding is a map of neural network input name and array
74
+ index that reader returns.
75
+ :type feeding: dict
73
76
:return:
74
77
"""
75
78
if event_handler is None :
76
79
event_handler = default_event_handler
77
-
78
- if reader_dict is None :
79
- reader_dict = self .default_reader_dict ()
80
-
81
80
__check_train_args__ (** locals ())
82
81
83
82
updater = self .__optimizer__ .create_local_updater ()
@@ -89,9 +88,7 @@ def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
89
88
pass_evaluator = self .__gradient_machine__ .makeEvaluator ()
90
89
assert isinstance (pass_evaluator , api .Evaluator )
91
90
out_args = api .Arguments .createArguments (0 )
92
-
93
- feeder = DataFeeder (self .__data_types__ , reader_dict )
94
-
91
+ feeder = DataFeeder (self .__data_types__ , feeding )
95
92
for pass_id in xrange (num_passes ):
96
93
event_handler (v2_event .BeginPass (pass_id ))
97
94
pass_evaluator .start ()
@@ -125,17 +122,8 @@ def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
125
122
event_handler (v2_event .EndPass (pass_id , evaluator = pass_evaluator ))
126
123
self .__gradient_machine__ .finish ()
127
124
128
- def default_reader_dict (self ):
129
- reader_dict = dict ()
130
- for i , tp in enumerate (self .__data_types__ ):
131
- reader_dict [tp [0 ]] = i
132
- return reader_dict
133
-
134
- def test (self , reader , reader_dict = None ):
135
- if reader_dict is None :
136
- reader_dict = self .default_reader_dict ()
137
-
138
- feeder = DataFeeder (self .__data_types__ , reader_dict )
125
+ def test (self , reader , feeding = None ):
126
+ feeder = DataFeeder (self .__data_types__ , feeding )
139
127
evaluator = self .__gradient_machine__ .makeEvaluator ()
140
128
out_args = api .Arguments .createArguments (0 )
141
129
evaluator .start ()
0 commit comments