Skip to content

Commit 8fa2140

Browse files
authored
use paddle.text.datasets.UCIHousing (#1999)
Co-authored-by: co63oc <[email protected]>
1 parent 0c0c07a commit 8fa2140

File tree

2 files changed

+9
-15
lines changed

2 files changed

+9
-15
lines changed

backends/npu/tests/unittests/test_adadelta_op_npu.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,12 @@ def test_adadelta(self):
188188
rms_optimizer.minimize(avg_cost)
189189

190190
fetch_list = [avg_cost]
191-
train_reader = paddle.batch(
192-
paddle.dataset.uci_housing.train(), batch_size=1
193-
)
191+
uci_housing = paddle.text.datasets.UCIHousing(mode="train")
194192
feeder = base.DataFeeder(place=place, feed_list=[x, y])
195193
exe = base.Executor(place)
196194
exe.run(base.default_startup_program())
197-
for data in train_reader():
198-
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
195+
for data in uci_housing:
196+
exe.run(main, feed=feeder.feed([data]), fetch_list=fetch_list)
199197

200198
def test_raise_error(self):
201199
self.assertRaises(ValueError, paddle.optimizer.Adadelta, None)

backends/sdaa/tests/unittests/test_momentum_op_sdaa.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,12 @@ def test_momentum(self):
143143
rms_optimizer.minimize(avg_cost)
144144

145145
fetch_list = [avg_cost]
146-
train_reader = paddle.batch(
147-
paddle.dataset.uci_housing.train(), batch_size=1
148-
)
146+
uci_housing = paddle.text.datasets.UCIHousing(mode="train")
149147
feeder = base.DataFeeder(place=place, feed_list=[x, y])
150148
exe = base.Executor(place)
151149
exe.run(base.default_startup_program())
152-
for data in train_reader():
153-
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
150+
for data in uci_housing:
151+
exe.run(main, feed=feeder.feed([data]), fetch_list=fetch_list)
154152

155153
def test_raise_error(self):
156154
self.assertRaises(ValueError, paddle.optimizer.Momentum, learning_rate=None)
@@ -267,14 +265,12 @@ def test_momentum_static(self):
267265
momentum_optimizer.minimize(avg_cost)
268266

269267
fetch_list = [avg_cost]
270-
train_reader = paddle.batch(
271-
paddle.dataset.uci_housing.train(), batch_size=1
272-
)
268+
uci_housing = paddle.text.datasets.UCIHousing(mode="train")
273269
feeder = base.DataFeeder(place=place, feed_list=[x, y])
274270
exe = base.Executor(place)
275271
exe.run(base.default_startup_program())
276-
for data in train_reader():
277-
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
272+
for data in uci_housing:
273+
exe.run(main, feed=feeder.feed([data]), fetch_list=fetch_list)
278274

279275

280276
class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase):

0 commit comments

Comments
 (0)