Skip to content

Commit 65e957c

Browse files
committed
Merge branch 'feature/mnist_train_api' of github.com:reyoung/Paddle into feature/mnist_train_api
2 parents f06b64f + a31ef0c commit 65e957c

File tree

4 files changed

+9
-5
lines changed

4 files changed

+9
-5
lines changed

demo/mnist/api_train.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@ def init_parameter(network):
1919
assert isinstance(network, api.GradientMachine)
2020
for each_param in network.getParameters():
2121
assert isinstance(each_param, api.Parameter)
22-
array = each_param.getBuf(api.PARAMETER_VALUE).toNumpyArrayInplace()
23-
assert isinstance(array, np.ndarray)
24-
for i in xrange(len(array)):
25-
array[i] = np.random.uniform(-1.0, 1.0)
22+
array_size = len(each_param)
23+
array = np.random.uniform(-1.0, 1.0, array_size).astype('float32')
24+
each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(array)
2625

2726

2827
def generator_to_batch(generator, batch_size):
@@ -175,7 +174,7 @@ def updater_callback(param):
175174
for each_param in params:
176175
assert isinstance(each_param, api.Parameter)
177176
value = each_param.getBuf(api.PARAMETER_VALUE)
178-
value = value.toNumpyArrayInplace()
177+
value = value.copyToNumpyArray()
179178

180179
# Here, we could save parameter to every where you want
181180
print each_param.getName(), value

paddle/api/Paddle.swig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ namespace std {
9696
%rename(__getitem__) Vector::get;
9797
%rename(__setitem__) Vector::set;
9898
%rename(__len__) Vector::getSize;
99+
%rename(__len__) Parameter::getSize;
99100
%rename(__call__) ParameterTraverseCallback::apply;
100101
%rename(__repr__) Evaluator::toString;
101102

paddle/api/PaddleAPI.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,8 @@ class Parameter {
550550
ParameterConfig* getConfig();
551551
void setValueUpdated();
552552

553+
size_t getSize() const;
554+
553555
private:
554556
static Parameter* createFromRawPtr(void* ptr);
555557
static Parameter* createFromSharedPtr(void* ptr);

paddle/api/Parameter.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,5 @@ ParameterConfig* Parameter::getConfig() {
5656
size_t Parameter::getID() const { return m->getPtr()->getID(); }
5757

5858
void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); }
59+
60+
size_t Parameter::getSize() const { return m->getPtr()->getSize(); }

0 commit comments

Comments
 (0)