@@ -25,9 +25,9 @@ void Pad<DEVICE_TYPE_CPU>(real* outputs,
25
25
const int inH,
26
26
const int inW,
27
27
const PadConf& pad) {
28
- int cstart = pad.channelStart , cend = pad.channelEnd ;
29
- int hstart = pad.heightStart , hend = pad.heightEnd ;
30
- int wstart = pad.widthStart , wend = pad.widthEnd ;
28
+ int cstart = pad.channel [ 0 ] , cend = pad.channel [ 1 ] ;
29
+ int hstart = pad.height [ 0 ] , hend = pad.height [ 1 ] ;
30
+ int wstart = pad.width [ 0 ] , wend = pad.width [ 1 ] ;
31
31
int outC = inC + cstart + cend;
32
32
int outH = inH + hstart + hend;
33
33
int outW = inW + wstart + wend;
@@ -51,9 +51,9 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad,
51
51
const int inH,
52
52
const int inW,
53
53
const PadConf& pad) {
54
- int cstart = pad.channelStart , cend = pad.channelEnd ;
55
- int hstart = pad.heightStart , hend = pad.heightEnd ;
56
- int wstart = pad.widthStart , wend = pad.widthEnd ;
54
+ int cstart = pad.channel [ 0 ] , cend = pad.channel [ 1 ] ;
55
+ int hstart = pad.height [ 0 ] , hend = pad.height [ 1 ] ;
56
+ int wstart = pad.width [ 0 ] , wend = pad.width [ 1 ] ;
57
57
int outC = inC + cstart + cend;
58
58
int outH = inH + hstart + hend;
59
59
int outW = inW + wstart + wend;
@@ -71,6 +71,12 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad,
71
71
}
72
72
}
73
73
74
+ static inline PadConf castToPadConf (const FuncConfig& conf) {
75
+ return {conf.get <std::vector<uint32_t >>(" channel" ),
76
+ conf.get <std::vector<uint32_t >>(" height" ),
77
+ conf.get <std::vector<uint32_t >>(" width" )};
78
+ }
79
+
74
80
/* *
75
81
* \brief Padding zeros to input according to the specify dimension.
76
82
* The struct pad_ contains the padding size in each dimension.
@@ -127,14 +133,7 @@ void PadGrad<DEVICE_TYPE_CPU>(real* inGrad,
127
133
template <DeviceType Device>
128
134
class PadFunc : public FunctionBase {
129
135
public:
130
- void init (const FuncConfig& config) override {
131
- pad_.channelStart = config.get <int >(" cstart" );
132
- pad_.channelEnd = config.get <int >(" cend" );
133
- pad_.heightStart = config.get <int >(" hstart" );
134
- pad_.heightEnd = config.get <int >(" hend" );
135
- pad_.widthStart = config.get <int >(" wstart" );
136
- pad_.widthEnd = config.get <int >(" wend" );
137
- }
136
+ void init (const FuncConfig& config) override { pad_ = castToPadConf (config); }
138
137
139
138
void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
140
139
CHECK_EQ (1UL , inputs.size ());
@@ -175,14 +174,7 @@ class PadFunc : public FunctionBase {
175
174
template <DeviceType Device>
176
175
class PadGradFunc : public FunctionBase {
177
176
public:
178
- void init (const FuncConfig& config) override {
179
- pad_.channelStart = config.get <int >(" cstart" );
180
- pad_.channelEnd = config.get <int >(" cend" );
181
- pad_.heightStart = config.get <int >(" hstart" );
182
- pad_.heightEnd = config.get <int >(" hend" );
183
- pad_.widthStart = config.get <int >(" wstart" );
184
- pad_.widthEnd = config.get <int >(" wend" );
185
- }
177
+ void init (const FuncConfig& config) override { pad_ = castToPadConf (config); }
186
178
187
179
void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
188
180
CHECK_EQ (1UL , inputs.size ());
0 commit comments