Merged
Changes from 4 commits
4 changes: 4 additions & 0 deletions paddle/api/Arguments.cpp
@@ -137,6 +137,10 @@ void Arguments::setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError) {
a.cpuSequenceDims = m->cast<paddle::IVector>(vec->getSharedPtr());
}

float Arguments::sumCosts() const {
Collaborator:

Function names should be Camel Cased: https://google.github.io/styleguide/cppguide.html#Function_Names

At least for newly added functions, shouldn't we follow the code style? That would at least remind everyone to pay attention to the convention. Existing functions could be renamed later with a tool.

Collaborator Author:

Please review baidu-adu/cpp-primer-digest#1. This is Paddle's current naming style.

return paddle::Argument::sumCosts(m->outputs);
}

int64_t Arguments::getBatchSize(size_t idx) const throw(RangeError) {
auto& a = m->getArg(idx);
return a.getBatchSize();
6 changes: 6 additions & 0 deletions paddle/api/PaddleAPI.h
@@ -454,6 +454,8 @@ class Arguments {
IVector* vec) throw(RangeError);
void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError);

float sumCosts() const;
Collaborator @wangkuiyi, Dec 27, 2016:

Please explain in the PR description why these functions need to be added, and what we gain from them:

  • sumCosts
  • load
  • save

Collaborator Author:

done.

Member:

LGTM


private:
static Arguments* createByPaddleArgumentVector(void* ptr);
void* getInternalArgumentsPtr() const;
@@ -549,6 +551,10 @@ class Parameter {
ParameterConfig* getConfig();
void setValueUpdated();

bool save(const std::string& filename) const;

bool load(const std::string& filename) const;

private:
static Parameter* createFromRawPtr(void* ptr);
static Parameter* createFromSharedPtr(void* ptr);
8 changes: 8 additions & 0 deletions paddle/api/Parameter.cpp
@@ -70,3 +70,11 @@ ParameterConfig* Parameter::getConfig() {
size_t Parameter::getID() const { return m->getPtr()->getID(); }

void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); }

bool Parameter::save(const std::string& filename) const {
Collaborator:

The main reason I have been holding off on approving this PR is that save/load(filename) methods are not a good design.

First, these methods are hard to unit test unless we have an in-memory mock filesystem, and we do not really need such a complex test facility. Moreover, the bodies of these methods often duplicate what the network-transfer methods do: both serialize/deserialize class members.

A more common design is to use serialize/deserialize in place of save/load:

std::string serialize();
error deserialize(const std::string& input);

This is easy to unit test, and equally easy to use for both network transfer and disk I/O:

File f("/tmp/a");
Parameters ps;
f.write(ps.serialize());

Member:

Very good point! The current interface is not suited to a distributed setting; serialize/deserialize would make transfer much more convenient. The interfaces exposed here are the old ones, merely exposed as-is; we can consider refactoring them when we provide the C API in the next step.
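The serialize/deserialize design suggested in this thread can be sketched with a toy class (hypothetical names, not Paddle's actual API). The point of the reviewer's argument is that a roundtrip unit test then needs no filesystem, while disk I/O or network transfer reduces to writing the serialized bytes:

```python
import struct

class ToyParameter:
    """Hypothetical stand-in for a parameter object; not Paddle's API."""

    def __init__(self, values):
        self.values = list(values)

    def serialize(self):
        # Pack the element count followed by the float32 values.
        return struct.pack("<I%df" % len(self.values),
                           len(self.values), *self.values)

    def deserialize(self, data):
        # Returns None on success, an error string on failure.
        if len(data) < 4:
            return "input too short"
        (n,) = struct.unpack_from("<I", data, 0)
        if len(data) != 4 + 4 * n:
            return "length mismatch"
        self.values = list(struct.unpack_from("<%df" % n, data, 4))
        return None

# The unit test needs no filesystem: roundtrip through in-memory bytes.
p = ToyParameter([0.5, 0.25, 1.0])
q = ToyParameter([])
assert q.deserialize(p.serialize()) is None
assert q.values == p.values

# Disk I/O (or network transfer) is then just a write of the same bytes:
# with open("/tmp/a", "wb") as f:
#     f.write(p.serialize())
```

The values in the sketch are exactly representable in float32, so the roundtrip comparison is exact; in general a tolerance would be needed.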

return m->getPtr()->save(filename);
}

bool Parameter::load(const std::string& filename) const {
return m->getPtr()->load(filename);
}
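The save/load wrappers above are exercised by the tests below as a save-then-load roundtrip keyed on the parameter's name. The pattern can be sketched with a hypothetical stand-in (toy code, not Paddle's swig API; the file name and format are invented), which also illustrates the reviewer's point that such tests need a writable filesystem:

```python
import os
import tempfile

class ToyParameter:
    """Hypothetical stand-in for the swig-exposed Parameter; not Paddle's API."""

    def __init__(self, name, values):
        self.name = name
        self.values = list(values)

    def getName(self):
        return self.name

    def save(self, filename):
        # Mirrors bool Parameter::save(filename): returns True on success.
        try:
            with open(filename, "w") as f:
                f.write(",".join(repr(v) for v in self.values))
            return True
        except OSError:
            return False

    def load(self, filename):
        # Mirrors bool Parameter::load(filename).
        try:
            with open(filename) as f:
                self.values = [float(v) for v in f.read().split(",")]
            return True
        except OSError:
            return False

# Roundtrip in the style of testGradientMachine.py: save under the
# parameter's name, then load it back into a fresh object.
tmpdir = tempfile.mkdtemp()
p = ToyParameter(os.path.join(tmpdir, "toy.w0"), [0.1, 0.1])
assert p.save(p.getName())
q = ToyParameter(p.getName(), [])
assert q.load(q.getName())
assert q.values == p.values
```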
2 changes: 2 additions & 0 deletions paddle/api/test/.gitignore
@@ -0,0 +1,2 @@
*.w0
*.wbias
2 changes: 2 additions & 0 deletions paddle/api/test/testArguments.py
@@ -22,6 +22,8 @@ def test_load_arguments(self):
args = swig_paddle.Arguments.createArguments(1)
args.setSlotValue(0, m)

self.assertAlmostEqual(27.0, args.sumCosts())

mat = args.getSlotValue(0)
assert isinstance(mat, swig_paddle.Matrix)
np_mat = mat.toNumpyMatInplace()
4 changes: 4 additions & 0 deletions paddle/api/test/testGradientMachine.py
@@ -45,6 +45,7 @@ def test_create_gradient_machine(self):
assert isinstance(val, swig_paddle.Vector)
arr = numpy.full((len(val), ), 0.1, dtype="float32")
val.copyFromNumpyArray(arr)
self.assertTrue(param.save(param.getName()))
param_config = param.getConfig().toProto()
assert isinstance(param_config,
paddle.proto.ParameterConfig_pb2.ParameterConfig)
@@ -92,6 +93,9 @@ def backward_callback(param_):

self.assertTrue(self.isCalled)

for param in machine.getParameters():
self.assertTrue(param.load(param.getName()))

def test_train_one_pass(self):
conf_file_path = './testTrainConfig.py'
trainer_config = swig_paddle.TrainerConfig.createFromTrainerConfigFile(