Skip to content

Commit a765c7c

Browse files
authored
Merge pull request #1013 from reyoung/feature/add_sum_cost_in_args
Add some functions to PaddleAPI.h
2 parents cb0a1e2 + b23d99d commit a765c7c

File tree

6 files changed

+26
-0
lines changed

6 files changed

+26
-0
lines changed

paddle/api/Arguments.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ void Arguments::setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError) {
137137
a.cpuSequenceDims = m->cast<paddle::IVector>(vec->getSharedPtr());
138138
}
139139

140+
float Arguments::sumCosts() const {
141+
return paddle::Argument::sumCosts(m->outputs);
142+
}
143+
140144
int64_t Arguments::getBatchSize(size_t idx) const throw(RangeError) {
141145
auto& a = m->getArg(idx);
142146
return a.getBatchSize();

paddle/api/PaddleAPI.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,8 @@ class Arguments {
450450
IVector* vec) throw(RangeError);
451451
void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError);
452452

453+
float sumCosts() const;
454+
453455
private:
454456
static Arguments* createByPaddleArgumentVector(void* ptr);
455457
void* getInternalArgumentsPtr() const;
@@ -546,6 +548,10 @@ class Parameter {
546548
ParameterConfig* getConfig();
547549
void setValueUpdated();
548550

551+
bool save(const std::string& filename) const;
552+
553+
bool load(const std::string& filename) const;
554+
549555
size_t getSize() const;
550556

551557
private:

paddle/api/Parameter.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,12 @@ size_t Parameter::getID() const { return m->getPtr()->getID(); }
5757

5858
void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); }
5959

60+
bool Parameter::save(const std::string& filename) const {
61+
return m->getPtr()->save(filename);
62+
}
63+
64+
bool Parameter::load(const std::string& filename) const {
65+
return m->getPtr()->load(filename);
66+
}
67+
6068
size_t Parameter::getSize() const { return m->getPtr()->getSize(); }

paddle/api/test/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*.w0
2+
*.wbias

paddle/api/test/testArguments.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def test_load_arguments(self):
2222
args = swig_paddle.Arguments.createArguments(1)
2323
args.setSlotValue(0, m)
2424

25+
self.assertAlmostEqual(27.0, args.sumCosts())
26+
2527
mat = args.getSlotValue(0)
2628
assert isinstance(mat, swig_paddle.Matrix)
2729
np_mat = mat.toNumpyMatInplace()

paddle/api/test/testGradientMachine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def test_create_gradient_machine(self):
4545
assert isinstance(val, swig_paddle.Vector)
4646
arr = numpy.full((len(val), ), 0.1, dtype="float32")
4747
val.copyFromNumpyArray(arr)
48+
self.assertTrue(param.save(param.getName()))
4849
param_config = param.getConfig().toProto()
4950
assert isinstance(param_config,
5051
paddle.proto.ParameterConfig_pb2.ParameterConfig)
@@ -92,6 +93,9 @@ def backward_callback(param_):
9293

9394
self.assertTrue(self.isCalled)
9495

96+
for param in machine.getParameters():
97+
self.assertTrue(param.load(param.getName()))
98+
9599
def test_train_one_pass(self):
96100
conf_file_path = './testTrainConfig.py'
97101
trainer_config = swig_paddle.TrainerConfig.createFromTrainerConfigFile(

0 commit comments

Comments
 (0)