Skip to content

Commit c9a9657

Browse files
authored
py_test and test_image_classification_train support argument (#5934)
* py_test support argument, test_image_classification_train support argument * use REMOVE_ITEM to rm item from list in cmake
1 parent 9de8ce1 commit c9a9657

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

cmake/generic.cmake

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -459,11 +459,11 @@ function(py_test TARGET_NAME)
459459
if(WITH_TESTING)
460460
set(options STATIC static SHARED shared)
461461
set(oneValueArgs "")
462-
set(multiValueArgs SRCS DEPS)
463-
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
462+
set(multiValueArgs SRCS DEPS ARGS)
463+
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
464464
add_test(NAME ${TARGET_NAME}
465465
COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python
466-
${PYTHON_EXECUTABLE} ${py_test_SRCS}
466+
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
467467
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
468468
endif()
469469
endfunction()
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
22
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
3+
4+
list(REMOVE_ITEM TEST_OPS test_image_classification_train)
5+
py_test(test_image_classification_train_resnet SRCS test_image_classification_train.py ARGS resnet)
6+
py_test(test_image_classification_train_vgg SRCS test_image_classification_train.py ARGS vgg)
7+
8+
# default test
39
foreach(src ${TEST_OPS})
410
py_test(${src} SRCS ${src}.py)
511
endforeach()

python/paddle/v2/fluid/tests/book/test_image_classification_train.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import print_function
2+
23
import numpy as np
34
import paddle.v2 as paddle
45
import paddle.v2.fluid as fluid
6+
import sys
57

68

79
def resnet_cifar10(input, depth=32):
@@ -80,11 +82,18 @@ def conv_block(input, num_filter, groups, dropouts):
8082
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
8183
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
8284

83-
# Add neural network config
84-
# option 1. resnet
85-
# net = resnet_cifar10(images, 32)
86-
# option 2. vgg
87-
net = vgg16_bn_drop(images)
85+
net_type = "vgg"
86+
if len(sys.argv) >= 2:
87+
net_type = sys.argv[1]
88+
89+
if net_type == "vgg":
90+
print("train vgg net")
91+
net = vgg16_bn_drop(images)
92+
elif net_type == "resnet":
93+
print("train resnet")
94+
net = resnet_cifar10(images, 32)
95+
else:
96+
raise ValueError("%s network is not supported" % net_type)
8897

8998
predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
9099
cost = fluid.layers.cross_entropy(input=predict, label=label)

0 commit comments

Comments
 (0)