
Commit 9b56074

Merge pull request #5705 from tensor-tang/mkldnn_concat
enable mkldnn_concat layer
2 parents ba86885 + 88feb51

File tree

6 files changed: +393 -4 lines changed
paddle/gserver/layers/MKLDNNConcatLayer.cpp (new file)

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MKLDNNConcatLayer.h"

using namespace mkldnn;  // NOLINT
typedef memory::format format;

namespace paddle {

REGISTER_LAYER(mkldnn_concat, MKLDNNConcatLayer);

bool MKLDNNConcatLayer::init(const LayerMap& layerMap,
                             const ParameterMap& parameterMap) {
  if (!MKLDNNLayer::init(layerMap, parameterMap)) {
    return false;
  }
  CHECK_GT(inputLayers_.size(), 1UL);
  CHECK(!biasParameter_);
  return true;
}

void MKLDNNConcatLayer::reshape(
    int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
  reshapeInput(bs, ih, iw);
  ic = inputLayers_[0]->getSize() / ih / iw;
  CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize());
  CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw);
  CHECK_GT(inputLayers_.size(), 1UL);
  channels_.resize(inputLayers_.size());
  channels_[0] = ic;
  // the output channel count changes here, so use oc_ instead
  // TODO(TJ): change API, use &oc
  oc_ = ic;
  for (size_t i = 1; i < inputLayers_.size(); i++) {
    int batchsize, height, width;
    reshapeInput(batchsize, height, width, i);
    CHECK_EQ(bs, batchsize);
    CHECK_EQ(ih, height);
    CHECK_EQ(iw, width);

    channels_[i] = inputLayers_[i]->getSize() / height / width;
    CHECK_EQ((size_t)channels_[i] * height * width, inputLayers_[i]->getSize());
    oc_ += channels_[i];
  }
  oh = ih;
  ow = iw;
  reshapeOutput(oh, ow);
  resizeOutput(bs, oc_ * oh * ow);
}

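For reference, the reshape above enforces the concat contract: all inputs must share the batch size and spatial size, only channel counts may differ, and the output channel count is their sum. A minimal standalone sketch of that bookkeeping, with hypothetical names that are not part of this commit:

#include <cassert>
#include <vector>

// Output channels of a channel-axis concat: the sum of per-input channels.
// Batch size, height, and width must already match across inputs.
int concatOutputChannels(const std::vector<int>& channels) {
  int oc = 0;
  for (int c : channels) oc += c;
  return oc;
}

int main() {
  std::vector<int> channels = {64, 32, 32};  // three inputs on 8x8 maps
  assert(concatOutputChannels(channels) == 128);
  // per-sample output size, as in resizeOutput: oc * oh * ow = 128 * 8 * 8
  return 0;
}
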
void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline,
                                 MKLDNNMatrixPtr& in,
                                 MKLDNNMatrixPtr& wgt,
                                 MKLDNNMatrixPtr& bias,
                                 MKLDNNMatrixPtr& out) {
  resetFwdBuffers(inVals_, out);
  in = inVals_[0];

  std::shared_ptr<concat::primitive_desc> fwdPD;
  resetFwdPD(fwdPD, inVals_, out);

  resetFwdPipeline(pipeline, fwdPD, inVals_, out);
}

void MKLDNNConcatLayer::resetBwd(std::vector<primitive>& pipeline,
                                 MKLDNNMatrixPtr& in,
                                 MKLDNNMatrixPtr& wgt,
                                 MKLDNNMatrixPtr& bias,
                                 MKLDNNMatrixPtr& out) {
  resetBwdBuffers(inGrads_, out);
  in = inGrads_[0];

  resetBwdPipeline(pipeline, bwds_, inGrads_, out);
}

void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
                                        MKLDNNMatrixPtr& out) {
  inputs.resize(inputLayers_.size());
  bool has8c = false, has16c = false, hasnc = false;
  for (size_t i = 0; i < inputs.size(); i++) {
    // resetInValue reads ic_, so temporarily set it to this input's channels
    // TODO(TJ): change ic_ to a vector, then channels_ can be removed
    ic_ = channels_[i];
    resetInValue(inputs[i], nullptr, i);
    CHECK(inputs[i]);
    auto dm = inputs[i]->getDims();
    // input formats may differ, but the number of dims must be equal
    CHECK(i == 0 || dm.size() == inputs[0]->getDims().size());
    CHECK_EQ(bs_, dm[0]);
    CHECK_EQ(channels_[i], dm[1]);
    if (dm.size() > 2) {
      CHECK_EQ(ih_, dm[2]);
      CHECK_EQ(iw_, dm[3]);
    }
    if (inputs[i]->getFormat() == format::nc) {
      hasnc = true;
    }
    if (inputs[i]->getFormat() == format::nChw8c) {
      has8c = true;
    }
    if (inputs[i]->getFormat() == format::nChw16c) {
      has16c = true;
    }
  }
  // change back; ic_ always holds the channel count of input 0
  ic_ = channels_[0];

  format outFmt;
  if (has16c && oc_ % 16 == 0) {
    outFmt = format::nChw16c;
  } else if (has8c && oc_ % 8 == 0) {
    outFmt = format::nChw8c;
  } else if (hasnc) {
    CHECK(oh_ == 1 && ow_ == 1);
    outFmt = format::nc;
  } else {
    outFmt = format::nchw;
  }
  memory::dims outDims =
      hasnc ? memory::dims{bs_, oc_} : memory::dims{bs_, oc_, oh_, ow_};
  auto outPD = MKLDNNMatrix::createPrimitiveDesc(outDims, outFmt, engine_);
  resetOutValue(out, outPD);
}

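The format choice above can be stated on its own: prefer the widest blocked layout seen among the inputs, but only while the summed channel count still divides by the block size; fall back to nc for 1x1 maps, else plain nchw. A standalone restatement under hypothetical names (the real code returns mkldnn's memory::format):

#include <string>

// Sketch of the output-format rule in resetFwdBuffers above.
std::string chooseOutFormat(bool has8c, bool has16c, bool hasnc, int oc) {
  if (has16c && oc % 16 == 0) return "nChw16c";
  if (has8c && oc % 8 == 0) return "nChw8c";
  if (hasnc) return "nc";  // only valid when oh == ow == 1
  return "nchw";
}

For example, two nChw16c inputs with 48 and 80 channels give oc = 128, divisible by 16, so the output stays nChw16c; with 48 and 72 channels (oc = 120, not divisible by 16) the rule falls back to plain nchw, since neither input was nChw8c.
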
void MKLDNNConcatLayer::resetFwdPD(std::shared_ptr<concat::primitive_desc>& pd,
                                   std::vector<MKLDNNMatrixPtr>& inputs,
                                   MKLDNNMatrixPtr out) {
  std::vector<memory::primitive_desc> srcPDs;
  for (size_t i = 0; i < inputs.size(); i++) {
    srcPDs.push_back(inputs[i]->getPrimitiveDesc());
  }
  CHECK(out);
  pd.reset(new concat::primitive_desc(out->getMemoryDesc(), axis_, srcPDs));
  CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc());
}

void MKLDNNConcatLayer::resetFwdPipeline(
    std::vector<primitive>& pipeline,
    std::shared_ptr<concat::primitive_desc>& pd,
    std::vector<MKLDNNMatrixPtr>& inputs,
    MKLDNNMatrixPtr& out) {
  std::vector<primitive::at> srcs;
  for (size_t i = 0; i < inputs.size(); i++) {
    srcs.push_back(*(inputs[i]));
  }
  fwd_.reset(new concat(*pd, srcs, *out));
  pipeline.push_back(*fwd_);
}

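What the concat primitive computes, written out as plain loops over nchw buffers — a reference sketch of the semantics only, not how MKL-DNN stores the data (blocked formats such as nChw8c interleave channels internally):

#include <vector>

// Channel-axis concat on plain nchw buffers: copy each input's channel
// block into the output at a running channel offset.
void concatChannels(const std::vector<std::vector<float>>& ins,
                    const std::vector<int>& channels,
                    int bs, int h, int w, std::vector<float>& out) {
  int oc = 0;
  for (int c : channels) oc += c;
  out.resize(static_cast<size_t>(bs) * oc * h * w);
  int offset = 0;  // plays the role of offsets[axis_] in the backward code below
  for (size_t i = 0; i < ins.size(); ++i) {
    for (int n = 0; n < bs; ++n) {
      for (int c = 0; c < channels[i]; ++c) {
        for (int p = 0; p < h * w; ++p) {
          out[(static_cast<size_t>(n) * oc + offset + c) * h * w + p] =
              ins[i][(static_cast<size_t>(n) * channels[i] + c) * h * w + p];
        }
      }
    }
    offset += channels[i];
  }
}
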
void MKLDNNConcatLayer::resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
                                        MKLDNNMatrixPtr& out) {
  CHECK(outVal_);
  resetOutGrad(out, outVal_->getPrimitiveDesc());
  CHECK(out);

  inputs.resize(inputLayers_.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    CHECK(inVals_[i]);
    // resetInGrad reads inVal_, so temporarily point it at this input's value
    // TODO(TJ): move inVals_ to MKLDNNLayer and remove inVal_
    inVal_ = inVals_[i];
    resetInGrad(inputs[i], inVals_[i]->getPrimitiveDesc(), i);
    CHECK_PRIMITIVE_DESC_EQ(inputs[i], inVals_[i]->getPrimitiveDesc());
  }
  // change back; inVal_ always holds the value of input 0
  inVal_ = inVals_[0];
}

void MKLDNNConcatLayer::resetBwdPipeline(
    std::vector<mkldnn::primitive>& pipeline,
    std::vector<std::shared_ptr<mkldnn::primitive>>& prims,
    std::vector<MKLDNNMatrixPtr>& inputs,
    MKLDNNMatrixPtr& out) {
  // reset the backward primitives
  memory::dims offsets = {0, 0, 0, 0};
  prims.resize(inputs.size());
  CHECK_EQ(inputs.size(), channels_.size());
  for (size_t i = 0; i < inputs.size(); i++) {
    auto viewPD = view::primitive_desc(
        out->getPrimitiveDesc(), inputs[i]->getDims(), offsets);
    auto bwdPD = reorder::primitive_desc(viewPD.dst_primitive_desc(),
                                         inputs[i]->getPrimitiveDesc());
    prims[i].reset(new reorder(bwdPD, *out, *(inputs[i])));
    offsets[axis_] += channels_[i];
    // push to pipeline
    pipeline.push_back(*prims[i]);
  }
}

}  // namespace paddle
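
Note that the backward pass needs no dedicated MKL-DNN primitive: each input gradient is a reorder from a view into the output gradient, with offsets[axis_] advanced by that input's channel count. On plain nchw buffers this is the exact inverse of the forward copy — a hedged sketch, not the committed implementation:

#include <vector>

// Slice the output gradient back into per-input gradients along channels,
// mirroring the per-input view + reorder built in resetBwdPipeline.
void splitChannels(const std::vector<float>& outGrad,
                   const std::vector<int>& channels,
                   int bs, int h, int w,
                   std::vector<std::vector<float>>& inGrads) {
  int oc = 0;
  for (int c : channels) oc += c;
  inGrads.clear();
  inGrads.resize(channels.size());
  int offset = 0;
  for (size_t i = 0; i < channels.size(); ++i) {
    inGrads[i].resize(static_cast<size_t>(bs) * channels[i] * h * w);
    for (int n = 0; n < bs; ++n) {
      for (int c = 0; c < channels[i]; ++c) {
        for (int p = 0; p < h * w; ++p) {
          inGrads[i][(static_cast<size_t>(n) * channels[i] + c) * h * w + p] =
              outGrad[(static_cast<size_t>(n) * oc + offset + c) * h * w + p];
        }
      }
    }
    offset += channels[i];
  }
}
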
paddle/gserver/layers/MKLDNNConcatLayer.h (new file)

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "MKLDNNLayer.h"
#include "mkldnn.hpp"

namespace paddle {

/**
 * @brief A subclass of MKLDNNLayer for the concatenation layer.
 *
 * The config file API is mkldnn_concat.
 */
class MKLDNNConcatLayer : public MKLDNNLayer {
protected:
  std::vector<MKLDNNMatrixPtr> inVals_;
  std::vector<MKLDNNMatrixPtr> inGrads_;
  std::vector<std::shared_ptr<mkldnn::primitive>> bwds_;
  // input channel numbers
  std::vector<int> channels_;

  // the concat dimension in MKLDNN
  // if axis_ == 0, concatenate along the batch dimension
  // if axis_ == 1, concatenate along the channel dimension (default)
  int axis_;

public:
  explicit MKLDNNConcatLayer(const LayerConfig& config)
      : MKLDNNLayer(config), axis_(1) {}

  ~MKLDNNConcatLayer() {}

  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  void reshape(
      int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;

  void resetFwd(std::vector<mkldnn::primitive>& pipeline,
                MKLDNNMatrixPtr& in,
                MKLDNNMatrixPtr& wgt,
                MKLDNNMatrixPtr& bias,
                MKLDNNMatrixPtr& out) override;

  void resetBwd(std::vector<mkldnn::primitive>& pipeline,
                MKLDNNMatrixPtr& in,
                MKLDNNMatrixPtr& wgt,
                MKLDNNMatrixPtr& bias,
                MKLDNNMatrixPtr& out) override;

  void printSizeInfo() override {
    CHECK_EQ(channels_.size(), inputLayers_.size());
    for (size_t i = 0; i < channels_.size(); ++i) {
      VLOG(MKLDNN_SIZES) << "Input " << i << ", " << inputLayers_[i]->getName()
                         << ": " << bs_ << ", " << channels_[i] << ", " << ih_
                         << ", " << iw_;
    }
    VLOG(MKLDNN_SIZES) << "Output: " << bs_ << ", " << oc_ << ", " << oh_
                       << ", " << ow_;
  }

  void printValueFormat() override {
    for (size_t i = 0; i < inVals_.size(); ++i) {
      VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName()
                        << ": " << inVals_[i]->getFormat() << " >>>";
    }
    if (outVal_) {
      VLOG(MKLDNN_FMTS) << outVal_->getFormat() << " >>> ";
    }
    if (extOutVal_) {
      VLOG(MKLDNN_FMTS) << extOutVal_->getFormat();
    }
  }

  void printGradFormat() override {
    if (extOutGrad_) {
      VLOG(MKLDNN_FMTS) << extOutGrad_->getFormat();
    }
    if (outGrad_) {
      VLOG(MKLDNN_FMTS) << outGrad_->getFormat() << " <<< ";
    }
    for (size_t i = 0; i < inGrads_.size(); ++i) {
      VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName()
                        << ": " << inGrads_[i]->getFormat() << " <<<";
    }
  }

protected:
  /**
   * Forward functions: reset buffers (inputs, output, bias),
   * reset primitive descriptor,
   * reset pipeline.
   */
  void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
                       MKLDNNMatrixPtr& out);
  void resetFwdPD(std::shared_ptr<mkldnn::concat::primitive_desc>& pd,
                  std::vector<MKLDNNMatrixPtr>& inputs,
                  MKLDNNMatrixPtr out);
  void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
                        std::shared_ptr<mkldnn::concat::primitive_desc>& pd,
                        std::vector<MKLDNNMatrixPtr>& inputs,
                        MKLDNNMatrixPtr& out);

  /**
   * Backward functions: reset buffers (inputs, output, bias),
   * reset primitives and pipeline.
   */
  void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
                       MKLDNNMatrixPtr& out);
  void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
                        std::vector<std::shared_ptr<mkldnn::primitive>>& prims,
                        std::vector<MKLDNNMatrixPtr>& inputs,
                        MKLDNNMatrixPtr& out);
};

}  // namespace paddle
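
On axis_: with the default axis_ == 1, inputs shaped (bs, c1, h, w) and (bs, c2, h, w) concatenate to (bs, c1 + c2, h, w); axis_ == 0 would instead stack along the batch dimension. A tiny illustrative helper, hypothetical and not part of this commit:

#include <array>

// Only the concatenated dimension may differ across inputs; it sums.
std::array<int, 4> concatShape(std::array<int, 4> a,
                               const std::array<int, 4>& b, int axis) {
  a[axis] += b[axis];
  return a;
}
// axis 1: {2, 64, 8, 8} and {2, 32, 8, 8} -> {2, 96, 8, 8}
// axis 0: {2, 64, 8, 8} and {4, 64, 8, 8} -> {6, 64, 8, 8}
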

paddle/gserver/layers/MKLDNNLayer.cpp

Lines changed: 5 additions & 2 deletions
@@ -138,8 +138,11 @@ void MKLDNNLayer::backward(const UpdateCallback& callback) {
   }
 }

-void MKLDNNLayer::reshapeInput(int& batchsize, int& height, int& width) {
-  const Argument& input = inputLayers_[0]->getOutput();
+void MKLDNNLayer::reshapeInput(int& batchsize,
+                               int& height,
+                               int& width,
+                               size_t inputIdx) {
+  const Argument& input = inputLayers_[inputIdx]->getOutput();
   batchsize = input.getBatchSize();
   int h = input.getFrameHeight();
   int w = input.getFrameWidth();

paddle/gserver/layers/MKLDNNLayer.h

Lines changed: 4 additions & 1 deletion
@@ -178,7 +178,10 @@ class MKLDNNLayer : public Layer {
   /**
    * reshape the input image sizes and input batchsize
    */
-  void reshapeInput(int& batchsize, int& height, int& width);
+  void reshapeInput(int& batchsize,
+                    int& height,
+                    int& width,
+                    size_t inputIdx = 0);

   /**
    * reshape output image sizes
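
Because the new inputIdx parameter defaults to 0, every existing caller of reshapeInput keeps its old behavior, while MKLDNNConcatLayer can address inputs beyond the first. A hypothetical standalone illustration of the pattern, not the real class:

#include <cstddef>

// The default argument keeps old three-argument call sites compiling
// unchanged, while new callers may pick an input index explicitly.
void reshapeInputDemo(int& batchsize, int& height, int& width,
                      std::size_t inputIdx = 0) {
  (void)inputIdx;  // a real implementation reads inputLayers_[inputIdx]
  batchsize = 2;
  height = 8;
  width = 8;
}

int main() {
  int bs = 0, h = 0, w = 0;
  reshapeInputDemo(bs, h, w);     // old signature, inputIdx == 0
  reshapeInputDemo(bs, h, w, 1);  // concat layer: query the second input
  return 0;
}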
