Skip to content

Commit 8b833d5

Browse files
committed
Add load/save method for Parameter
1 parent 9601c2f commit 8b833d5

File tree

4 files changed

+22
-0
lines changed

4 files changed

+22
-0
lines changed

paddle/api/PaddleAPI.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,10 @@ class Parameter {
551551
ParameterConfig* getConfig();
552552
void setValueUpdated();
553553

554+
bool save(const std::string& filename) const;
555+
556+
bool load(const std::string& filename) const;
557+
554558
private:
555559
static Parameter* createFromRawPtr(void* ptr);
556560
static Parameter* createFromSharedPtr(void* ptr);

paddle/api/Parameter.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,11 @@ ParameterConfig* Parameter::getConfig() {
7070
size_t Parameter::getID() const { return m->getPtr()->getID(); }
7171

7272
void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); }
73+
74+
bool Parameter::save(const std::string& filename) const {
75+
return m->getPtr()->save(filename);
76+
}
77+
78+
bool Parameter::load(const std::string& filename) const {
79+
return m->getPtr()->load(filename);
80+
}

paddle/api/test/.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
___fc_layer_0__.w0
2+
___fc_layer_0__.wbias
3+
_hidden1.w0
4+
_hidden1.wbias
5+
_hidden2.w0
6+
_hidden2.wbias

paddle/api/test/testGradientMachine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def test_create_gradient_machine(self):
4545
assert isinstance(val, swig_paddle.Vector)
4646
arr = numpy.full((len(val), ), 0.1, dtype="float32")
4747
val.copyFromNumpyArray(arr)
48+
self.assertTrue(param.save(param.getName()))
4849
param_config = param.getConfig().toProto()
4950
assert isinstance(param_config,
5051
paddle.proto.ParameterConfig_pb2.ParameterConfig)
@@ -92,6 +93,9 @@ def backward_callback(param_):
9293

9394
self.assertTrue(self.isCalled)
9495

96+
for param in machine.getParameters():
97+
self.assertTrue(param.load(param.getName()))
98+
9599
def test_train_one_pass(self):
96100
conf_file_path = './testTrainConfig.py'
97101
trainer_config = swig_paddle.TrainerConfig.createFromTrainerConfigFile(

0 commit comments

Comments
 (0)