@@ -143,14 +143,12 @@ def test_momentum(self):
143
143
rms_optimizer .minimize (avg_cost )
144
144
145
145
fetch_list = [avg_cost ]
146
- train_reader = paddle .batch (
147
- paddle .dataset .uci_housing .train (), batch_size = 1
148
- )
146
+ uci_housing = paddle .text .datasets .UCIHousing (mode = "train" )
149
147
feeder = base .DataFeeder (place = place , feed_list = [x , y ])
150
148
exe = base .Executor (place )
151
149
exe .run (base .default_startup_program ())
152
- for data in train_reader () :
153
- exe .run (main , feed = feeder .feed (data ), fetch_list = fetch_list )
150
+ for data in uci_housing :
151
+ exe .run (main , feed = feeder .feed ([ data ] ), fetch_list = fetch_list )
154
152
155
153
def test_raise_error (self ):
156
154
self .assertRaises (ValueError , paddle .optimizer .Momentum , learning_rate = None )
@@ -267,14 +265,12 @@ def test_momentum_static(self):
267
265
momentum_optimizer .minimize (avg_cost )
268
266
269
267
fetch_list = [avg_cost ]
270
- train_reader = paddle .batch (
271
- paddle .dataset .uci_housing .train (), batch_size = 1
272
- )
268
+ uci_housing = paddle .text .datasets .UCIHousing (mode = "train" )
273
269
feeder = base .DataFeeder (place = place , feed_list = [x , y ])
274
270
exe = base .Executor (place )
275
271
exe .run (base .default_startup_program ())
276
- for data in train_reader () :
277
- exe .run (main , feed = feeder .feed (data ), fetch_list = fetch_list )
272
+ for data in uci_housing :
273
+ exe .run (main , feed = feeder .feed ([ data ] ), fetch_list = fetch_list )
278
274
279
275
280
276
class TestMomentumOpVsMomentumOpWithDecayAPI (unittest .TestCase ):
0 commit comments