
Commit 3fffcd4

Merge pull request #1755 from reyoung/feature/add_any_in_paddle
Using linb::any/std::any instead of FunctionConfig
2 parents 4b1b599 + 36524bb commit 3fffcd4

File tree: 9 files changed, +116 −126 lines

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ include(external/python)    # download, build, install python
 include(external/openblas)  # download, build, install openblas
 include(external/swig)      # download, build, install swig
 include(external/warpctc)   # download, build, install warpctc
+include(external/any)       # download libn::any
 
 include(package)            # set paddle packages
 include(cpplint)            # set paddle c++ style

cmake/external/any.cmake

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+INCLUDE(ExternalProject)
+
+SET(ANY_SOURCE_DIR ${THIRD_PARTY_PATH}/any)
+
+INCLUDE_DIRECTORIES(${ANY_SOURCE_DIR}/src/linb_any)
+
+ExternalProject_Add(
+    linb_any
+    ${EXTERNAL_PROJECT_LOG_ARGS}
+    GIT_REPOSITORY  "https://github.com/thelink2012/any.git"
+    GIT_TAG         "8fef1e93710a0edf8d7658999e284a1142c4c020"
+    PREFIX          ${ANY_SOURCE_DIR}
+    UPDATE_COMMAND    ""
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND     ""
+    INSTALL_COMMAND   ""
+    TEST_COMMAND      ""
+)
+
+add_definitions(-DANY_IMPL_ANY_CAST_MOVEABLE)
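The external project only clones thelink2012's header-only any backport; nothing is configured, built, or installed, and the INCLUDE_DIRECTORIES line puts the checkout on the include path so <any.hpp> can be included directly. ANY_IMPL_ANY_CAST_MOVEABLE is a configuration macro consumed by that header. A minimal sketch of the drop-in API the fetched header provides; the test file below is hypothetical and not part of this commit:

#include <any.hpp>  // header-only, made visible by the INCLUDE_DIRECTORIES above
#include <cassert>
#include <string>

int main() {
  linb::any box = std::string("pad");           // stores any copyable type
  assert(linb::any_cast<std::string>(box) == "pad");
  try {
    linb::any_cast<int>(box);                   // wrong type
  } catch (const linb::bad_any_cast&) {
    return 0;                                   // mismatches surface as bad_any_cast
  }
  return 1;
}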

paddle/function/Function.cpp

Lines changed: 0 additions & 60 deletions
@@ -16,66 +16,6 @@ limitations under the License. */
 
 namespace paddle {
 
-template <>
-size_t FuncConfig::get<size_t>(const std::string& key) const {
-  auto it = valueMap_.find(key);
-  CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
-  return it->second.s;
-}
-
-template <>
-real FuncConfig::get<real>(const std::string& key) const {
-  auto it = valueMap_.find(key);
-  CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
-  return it->second.r;
-}
-
-template <>
-int FuncConfig::get<int>(const std::string& key) const {
-  auto it = valueMap_.find(key);
-  CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
-  return it->second.i;
-}
-
-template <>
-bool FuncConfig::get<bool>(const std::string& key) const {
-  auto it = valueMap_.find(key);
-  CHECK(it != valueMap_.end()) << "Cannot find value: '" << key << "'";
-  return it->second.b;
-}
-
-template <>
-FuncConfig& FuncConfig::set<size_t>(const std::string& key, size_t v) {
-  CHECK_EQ(static_cast<int>(valueMap_.count(key)), 0) << "Duplicated value: "
-                                                      << key;
-  valueMap_[key].s = v;
-  return *this;
-}
-
-template <>
-FuncConfig& FuncConfig::set<real>(const std::string& key, real v) {
-  CHECK_EQ(static_cast<int>(valueMap_.count(key)), 0) << "Duplicated value: "
-                                                      << key;
-  valueMap_[key].r = v;
-  return *this;
-}
-
-template <>
-FuncConfig& FuncConfig::set<int>(const std::string& key, int v) {
-  CHECK_EQ(static_cast<int>(valueMap_.count(key)), 0) << "Duplicated value: "
-                                                      << key;
-  valueMap_[key].i = v;
-  return *this;
-}
-
-template <>
-FuncConfig& FuncConfig::set<bool>(const std::string& key, bool v) {
-  CHECK_EQ(static_cast<int>(valueMap_.count(key)), 0) << "Duplicated value: "
-                                                      << key;
-  valueMap_[key].b = v;
-  return *this;
-}
-
 void BufferArgs::addArg(const Matrix& arg,
                         const TensorShape& shape,
                         ArgType argType) {
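For context, the removed specializations only covered the four scalar members of the old value union. The sketch below is illustrative only (not from the commit) and shows why that storage scheme could not hold the vector-valued pad configuration introduced later in this diff:

#include <cstddef>

using real = float;  // stand-in for Paddle's `real` typedef (float or double)

// Rough shape of the removed backing store:
union OldValue {
  size_t s;
  real r;
  int i;
  bool b;
  // Adding e.g. `std::vector<uint32_t> v;` would be legal C++11, but a member
  // with a non-trivial destructor deletes the union's implicit special member
  // functions, so copies and destruction would have to be managed by hand.
  // Type-erasing the value into linb::any / std::any does that bookkeeping instead.
};

int main() {
  OldValue v;
  v.i = 3;
  return v.i == 3 ? 0 : 1;
}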

paddle/function/Function.h

Lines changed: 28 additions & 11 deletions
@@ -18,32 +18,49 @@ limitations under the License. */
 #include <vector>
 #include "BufferArg.h"
 #include "paddle/math/Matrix.h"
+#include "paddle/utils/Any.h"
 #include "paddle/utils/ClassRegistrar.h"
+#include "paddle/utils/Error.h"
 
 namespace paddle {
 
 /**
  * Function Configuration.
  * The argument type of Function::init.
- * Follow-up will consider moving this data structure to Proto inside.
  */
 class FuncConfig {
 public:
-  union value {
-    size_t s;
-    real r;
-    int i;
-    bool b;
-  };
-
   template <typename T>
-  T get(const std::string& key) const;
+  T get(const std::string& key, Error* err = nullptr) const {
+    try {
+      return any_cast<T>(valueMap_.at(key));
+    } catch (std::exception& e) {  // could be cast or out of range exception.
+      if (err) {
+        *err = Error(e.what());
+      } else {
+        LOG(FATAL) << "Cannot get key " << key << "with error " << e.what();
+      }
+      return T();
+    }
+  }
 
   template <typename T>
-  FuncConfig& set(const std::string& key, T v);
+  FuncConfig& set(const std::string& key, T v, Error* err = nullptr) {
+    auto it = valueMap_.find(key);
+    if (it != valueMap_.end()) {  // already contains key.
+      if (err) {
+        *err = Error("Key %s is already set in FuncConfig", key.c_str());
+      } else {
+        LOG(FATAL) << "Key " << key << " is already set in FuncConfig.";
+      }
+      return *this;
+    }
+    valueMap_[key] = any(v);
+    return *this;
+  }
 
 protected:
-  std::map<std::string, value> valueMap_;
+  mutable std::unordered_map<std::string, any> valueMap_;
 };
 
 /**
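With this change FuncConfig becomes header-only and can hold any copyable type instead of just the four union scalars. A minimal usage sketch; the keys "channel" and "algo" are illustrative (only "channel"/"height"/"width" appear elsewhere in this commit), and the free function wrapping the calls is hypothetical:

#include "paddle/function/Function.h"

#include <cstdint>
#include <vector>

void funcConfigExample() {
  paddle::FuncConfig config;
  // set<T>() deduces T from the argument and rejects duplicate keys.
  config.set("channel", std::vector<uint32_t>{2, 2}).set("algo", size_t(1));

  // get<T>() must name the exact stored type; on a type mismatch or missing
  // key the Error out-parameter is filled in (without it, the call LOG(FATAL)s).
  auto channel = config.get<std::vector<uint32_t>>("channel");
  paddle::Error err;
  int wrong = config.get<int>("algo", &err);  // stored as size_t, so err is set
  (void)channel;
  (void)wrong;
  (void)err;
}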

paddle/function/PadOp.cpp

Lines changed: 14 additions & 22 deletions
@@ -25,9 +25,9 @@ void Pad<DEVICE_TYPE_CPU>(real* outputs,
                           const int inH,
                           const int inW,
                           const PadConf& pad) {
-  int cstart = pad.channelStart, cend = pad.channelEnd;
-  int hstart = pad.heightStart, hend = pad.heightEnd;
-  int wstart = pad.widthStart, wend = pad.widthEnd;
+  int cstart = pad.channel[0], cend = pad.channel[1];
+  int hstart = pad.height[0], hend = pad.height[1];
+  int wstart = pad.width[0], wend = pad.width[1];
   int outC = inC + cstart + cend;
   int outH = inH + hstart + hend;
   int outW = inW + wstart + wend;
@@ -51,9 +51,9 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad,
                               const int inH,
                               const int inW,
                               const PadConf& pad) {
-  int cstart = pad.channelStart, cend = pad.channelEnd;
-  int hstart = pad.heightStart, hend = pad.heightEnd;
-  int wstart = pad.widthStart, wend = pad.widthEnd;
+  int cstart = pad.channel[0], cend = pad.channel[1];
+  int hstart = pad.height[0], hend = pad.height[1];
+  int wstart = pad.width[0], wend = pad.width[1];
   int outC = inC + cstart + cend;
   int outH = inH + hstart + hend;
   int outW = inW + wstart + wend;
@@ -71,6 +71,12 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad,
   }
 }
 
+static inline PadConf castToPadConf(const FuncConfig& conf) {
+  return {conf.get<std::vector<uint32_t>>("channel"),
+          conf.get<std::vector<uint32_t>>("height"),
+          conf.get<std::vector<uint32_t>>("width")};
+}
+
 /**
  * \brief Padding zeros to input according to the specify dimension.
  *        The struct pad_ contains the padding size in each dimension.
@@ -127,14 +133,7 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad,
 template <DeviceType Device>
 class PadFunc : public FunctionBase {
 public:
-  void init(const FuncConfig& config) override {
-    pad_.channelStart = config.get<int>("cstart");
-    pad_.channelEnd = config.get<int>("cend");
-    pad_.heightStart = config.get<int>("hstart");
-    pad_.heightEnd = config.get<int>("hend");
-    pad_.widthStart = config.get<int>("wstart");
-    pad_.widthEnd = config.get<int>("wend");
-  }
+  void init(const FuncConfig& config) override { pad_ = castToPadConf(config); }
 
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(1UL, inputs.size());
@@ -175,14 +174,7 @@ class PadFunc : public FunctionBase {
 template <DeviceType Device>
 class PadGradFunc : public FunctionBase {
 public:
-  void init(const FuncConfig& config) override {
-    pad_.channelStart = config.get<int>("cstart");
-    pad_.channelEnd = config.get<int>("cend");
-    pad_.heightStart = config.get<int>("hstart");
-    pad_.heightEnd = config.get<int>("hend");
-    pad_.widthStart = config.get<int>("wstart");
-    pad_.widthEnd = config.get<int>("wend");
-  }
+  void init(const FuncConfig& config) override { pad_ = castToPadConf(config); }
 
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(1UL, inputs.size());

paddle/function/PadOp.h

Lines changed: 6 additions & 12 deletions
@@ -19,18 +19,12 @@ limitations under the License. */
 namespace paddle {
 
 struct PadConf {
-  /// how many values to add before the data along channel dimension.
-  int channelStart;
-  /// how many values to add after the data along channel dimension.
-  int channelEnd;
-  /// how many values to add before the data along height dimension.
-  int heightStart;
-  /// how many values to add after the data along height dimension.
-  int heightEnd;
-  /// how many values to add before the data along width dimension.
-  int widthStart;
-  /// how many values to add after the data along width dimension.
-  int widthEnd;
+  /// how many values to add before/after the data along channel dimension.
+  std::vector<uint32_t> channel;
+  /// how many values to add before/after the data along height dimension.
+  std::vector<uint32_t> height;
+  /// how many values to add before/after the data along width dimension.
+  std::vector<uint32_t> width;
 };
 
 /**
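Because PadConf is now a plain aggregate of three vectors, castToPadConf in PadOp.cpp can build it with a braced initializer in member order. A small self-contained sketch of that pattern (the struct mirrors the new PadConf; the values are made up):

#include <cstdint>
#include <vector>

// Same layout as the new paddle::PadConf.
struct PadConf {
  std::vector<uint32_t> channel;  // {before, after} along the channel dimension
  std::vector<uint32_t> height;   // {before, after} along the height dimension
  std::vector<uint32_t> width;    // {before, after} along the width dimension
};

int main() {
  // Aggregate initialization in declaration order, as castToPadConf relies on.
  PadConf pad = {{1, 1}, {2, 2}, {0, 0}};
  // Access pattern used by Pad<DEVICE_TYPE_CPU> / PadGrad<DEVICE_TYPE_CPU>:
  int cstart = pad.channel[0], cend = pad.channel[1];
  return (cstart == 1 && cend == 1 && pad.height[1] == 2 && pad.width[0] == 0) ? 0 : 1;
}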

paddle/gserver/layers/PadLayer.cpp

Lines changed: 9 additions & 18 deletions
@@ -36,34 +36,25 @@ bool PadLayer::init(const LayerMap& layerMap,
   CHECK_EQ(2, pad_conf.pad_c_size());
   CHECK_EQ(2, pad_conf.pad_h_size());
   CHECK_EQ(2, pad_conf.pad_w_size());
-  padc_.push_back(pad_conf.pad_c(0));
-  padc_.push_back(pad_conf.pad_c(1));
-  padh_.push_back(pad_conf.pad_h(0));
-  padh_.push_back(pad_conf.pad_h(1));
-  padw_.push_back(pad_conf.pad_w(0));
-  padw_.push_back(pad_conf.pad_w(1));
+  padc_ = {pad_conf.pad_c(0), pad_conf.pad_c(1)};
+  padh_ = {pad_conf.pad_h(0), pad_conf.pad_h(1)};
+  padw_ = {pad_conf.pad_w(0), pad_conf.pad_w(1)};
 
   outDims_ = TensorShape(4);
   setOutDims(0);
 
   createFunction(forward_,
                  "Pad",
                  FuncConfig()
-                     .set("cstart", padc_[0])
-                     .set("cend", padc_[1])
-                     .set("hstart", padh_[0])
-                     .set("hend", padh_[1])
-                     .set("wstart", padw_[0])
-                     .set("wend", padw_[1]));
+                     .set("channel", padc_)
+                     .set("height", padh_)
+                     .set("width", padw_));
   createFunction(backward_,
                  "PadGrad",
                  FuncConfig()
-                     .set("cstart", padc_[0])
-                     .set("cend", padc_[1])
-                     .set("hstart", padh_[0])
-                     .set("hend", padh_[1])
-                     .set("wstart", padw_[0])
-                     .set("wend", padw_[1]));
+                     .set("channel", padc_)
+                     .set("height", padh_)
+                     .set("width", padw_));
 
   return true;
 }

paddle/gserver/layers/PadLayer.h

Lines changed: 3 additions & 3 deletions
@@ -38,9 +38,9 @@ class PadLayer : public Layer {
   void setOutDims(const size_t batchSize);
   void setTensorDim(const size_t batchSize);
 
-  std::vector<int> padc_;
-  std::vector<int> padh_;
-  std::vector<int> padw_;
+  std::vector<uint32_t> padc_;
+  std::vector<uint32_t> padh_;
+  std::vector<uint32_t> padw_;
   TensorShape inDims_;
   TensorShape outDims_;
 };

paddle/utils/Any.h

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#if __cplusplus > 201402L
+#include <any>
+
+namespace paddle {
+// using std::any for C++ 17
+using std::any;
+using std::any_cast;
+using std::bad_any_cast;
+}  // namespace paddle
+
+#else
+#include <any.hpp>
+
+namespace paddle {
+// use linb::any for C++ 11
+using linb::any;
+using linb::any_cast;
+using linb::bad_any_cast;
+}  // namespace paddle
+#endif
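The alias header picks std::any when compiling as C++17 (__cplusplus > 201402L) and falls back to the bundled linb::any otherwise, so callers such as FuncConfig only ever name paddle::any. A hedged sketch of the contract the rest of the commit relies on (the helper function is hypothetical):

#include "paddle/utils/Any.h"

#include <string>

// In both backends bad_any_cast derives from std::exception, which is why
// FuncConfig::get() can catch a single std::exception for "wrong type" as
// well as the map's std::out_of_range for "missing key".
bool isString(const paddle::any& value) {
  try {
    paddle::any_cast<std::string>(value);
    return true;
  } catch (const paddle::bad_any_cast&) {
    return false;
  }
}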
