Skip to content

Commit 43d92fa

Browse files
committed
Make api_train_v2 runnable
1 parent 7293c82 commit 43d92fa

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

demo/mnist/api_train_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ def main():
2020

2121
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
2222

23-
trainer = paddle.trainer.SGD(topology=cost,
23+
trainer = paddle.trainer.SGD(cost=cost,
2424
parameters=parameters,
2525
update_equation=adam_optimizer)
2626

2727
def event_handler(event):
2828
if isinstance(event, paddle.event.EndIteration):
2929
if event.batch_id % 1000 == 0:
3030
result = trainer.test(reader=paddle.reader.batched(
31-
paddle.dataset.mnist.test_creator(), batch_size=256))
31+
paddle.dataset.mnist.test(), batch_size=256))
3232

3333
print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
3434
event.pass_id, event.batch_id, event.cost, event.metrics,

python/paddle/v2/dataset/mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
1111
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
12-
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
12+
TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
1313
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
14-
TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688'
14+
TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
1515
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
1616
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
1717
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'

python/paddle/v2/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, cost, parameters, update_equation):
6161
self.__topology__ = topology
6262
self.__parameters__ = parameters
6363
self.__topology_in_proto__ = topology.proto()
64-
self.__data_types__ = topology.data_layers()
64+
self.__data_types__ = topology.data_type()
6565
gm = api.GradientMachine.createFromConfigProto(
6666
self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
6767
self.__optimizer__.enable_types())

0 commit comments

Comments
 (0)