Skip to content

Commit 54205c9

Browse files
committed
add MKLDNNLRNLayer
1 parent 4786ad1 commit 54205c9

File tree

2 files changed

+241
-0
lines changed

2 files changed

+241
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "MKLDNNLRNLayer.h"
16+
#include "paddle/utils/Logging.h"
17+
18+
using namespace mkldnn;  // NOLINT
typedef memory::format format;

namespace paddle {

// Expose this layer to the config parser under the name "mkldnn_lrn".
REGISTER_LAYER(mkldnn_lrn, MKLDNNLRNLayer);
24+
25+
bool MKLDNNLRNLayer::init(const LayerMap& layerMap,
26+
const ParameterMap& parameterMap) {
27+
if (!MKLDNNLayer::init(layerMap, parameterMap)) {
28+
return false;
29+
}
30+
31+
/* the size of inputs for norm-layer is 1 */
32+
CHECK_EQ(config_.inputs_size(), 1UL);
33+
const NormConfig& conf = config_.inputs(0).norm_conf();
34+
localSize_ = conf.size();
35+
alpha_ = conf.scale();
36+
beta_ = conf.pow();
37+
38+
ic_ = conf.channels();
39+
oc_ = ic_;
40+
iw_ = conf.img_size();
41+
ow_ = conf.output_x();
42+
ih_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
43+
oh_ = conf.has_output_y() ? conf.output_y() : conf.output_x();
44+
CHECK_EQ(iw_, ow_);
45+
CHECK_EQ(ih_, oh_);
46+
return true;
47+
}
48+
49+
void MKLDNNLRNLayer::reshape(
50+
int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
51+
CHECK_EQ(inputLayers_.size(), 1UL);
52+
reshapeInput(bs, ih, iw);
53+
// ic_ and oc can not be changed
54+
CHECK_EQ((size_t)ic,
55+
inputLayers_[0]->getOutputValue()->getElementCnt() / bs / ih / iw)
56+
<< "Input channel can not be changed";
57+
oh = ih;
58+
ow = iw;
59+
reshapeOutput(oh, ow);
60+
resizeOutput(bs, oc * oh * ow);
61+
}
62+
63+
void MKLDNNLRNLayer::resetFwd(std::vector<primitive>& pipeline,
64+
std::vector<MKLDNNMatrixPtr>& inputs,
65+
MKLDNNMatrixPtr& out) {
66+
resetFwdBuffers(inputs[0], out);
67+
68+
resetFwdPD(fwdPD_, inputs[0], out);
69+
70+
resetFwdPipeline(pipeline, fwdPD_, inputs[0], out);
71+
}
72+
73+
void MKLDNNLRNLayer::resetBwd(std::vector<primitive>& pipeline,
74+
std::vector<MKLDNNMatrixPtr>& inputs,
75+
MKLDNNMatrixPtr& out) {
76+
std::shared_ptr<lrn_bwd::primitive_desc> pd;
77+
78+
resetBwdBuffers(inputs[0], out);
79+
80+
resetBwdPD(pd, inputs[0], out);
81+
82+
resetBwdPipeline(pipeline, pd, inputs[0], out);
83+
}
84+
85+
void MKLDNNLRNLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
                                     MKLDNNMatrixPtr& out) {
  // Create the input value buffer first; the output value then adopts the
  // same primitive descriptor (LRN preserves shape and format).
  resetInValue(in);
  CHECK(in);
  resetOutValue(out, in->getPrimitiveDesc());
}
91+
92+
// Build the forward primitive descriptor and, in training, the workspace
// memory that MKL-DNN's LRN backward pass requires.
// `out` is unused here: the output format follows the input format.
void MKLDNNLRNLayer::resetFwdPD(std::shared_ptr<lrn_fwd::primitive_desc>& pd,
                                MKLDNNMatrixPtr in,
                                MKLDNNMatrixPtr out) {
  // Scoring-only forward skips bookkeeping needed for backward.
  prop_kind pk = passType_ == PASS_TEST ? prop_kind::forward_scoring
                                        : prop_kind::forward_training;
  // Last argument (1.0f) is the LRN bias constant k.
  auto fwdDesc = lrn_fwd::desc(pk,
                               algorithm::lrn_across_channels,
                               in->getMemoryDesc(),
                               localSize_,
                               alpha_,
                               beta_,
                               1.0f);
  pd.reset(new lrn_fwd::primitive_desc(fwdDesc, engine_));
  // prepare workspace if necessary; only the training pd exposes a
  // workspace primitive descriptor, which backward consumes later.
  workspace_ =
      passType_ != PASS_TEST
          ? std::make_shared<memory>(memory(pd->workspace_primitive_desc()))
          : nullptr;
}
111+
112+
void MKLDNNLRNLayer::resetFwdPipeline(
113+
std::vector<primitive>& pipeline,
114+
std::shared_ptr<lrn_fwd::primitive_desc>& pd,
115+
MKLDNNMatrixPtr& in,
116+
MKLDNNMatrixPtr& out) {
117+
fwd_ = workspace_
118+
? std::make_shared<lrn_fwd>(lrn_fwd(*pd, *in, *workspace_, *out))
119+
: std::make_shared<lrn_fwd>(lrn_fwd(*pd, *in, *out));
120+
pipeline.push_back(*fwd_);
121+
}
122+
123+
void MKLDNNLRNLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
                                     MKLDNNMatrixPtr& out) {
  // Gradient buffers reuse the primitive descriptors chosen for the
  // corresponding value buffers during the forward pass.
  CHECK(inVals_[0] && outVal_);
  resetOutGrad(out, outVal_->getPrimitiveDesc());
  resetInGrad(in, inVals_[0]->getPrimitiveDesc());
}
129+
130+
void MKLDNNLRNLayer::resetBwdPD(std::shared_ptr<lrn_bwd::primitive_desc>& pd,
                                MKLDNNMatrixPtr& in,
                                MKLDNNMatrixPtr& out) {
  pd = nullptr;
  // No input gradient requested: nothing to build.
  if (in == nullptr) {
    return;
  }
  CHECK(out);
  // Same LRN parameters as the forward descriptor (1.0f is the bias k).
  auto bwdDesc = lrn_bwd::desc(algorithm::lrn_across_channels,
                               in->getMemoryDesc(),
                               out->getMemoryDesc(),
                               localSize_,
                               alpha_,
                               beta_,
                               1.0f);
  // The backward pd must be created against the saved forward pd.
  pd = std::make_shared<lrn_bwd::primitive_desc>(bwdDesc, engine_, *fwdPD_);
}
147+
148+
void MKLDNNLRNLayer::resetBwdPipeline(
149+
std::vector<primitive>& pipeline,
150+
std::shared_ptr<lrn_bwd::primitive_desc>& pd,
151+
MKLDNNMatrixPtr& in,
152+
MKLDNNMatrixPtr& out) {
153+
if (pd == nullptr) {
154+
return;
155+
}
156+
CHECK(inVals_[0]);
157+
CHECK(workspace_);
158+
bwdData_ = std::make_shared<lrn_bwd>(
159+
lrn_bwd(*pd, *inVals_[0], *out, *workspace_, *in));
160+
pipeline.push_back(*bwdData_);
161+
}
162+
163+
} // namespace paddle
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "MKLDNNLayer.h"
18+
#include "mkldnn.hpp"
19+
20+
namespace paddle {
// Short aliases for the MKL-DNN LRN primitives used throughout this layer.
typedef mkldnn::lrn_forward lrn_fwd;
typedef mkldnn::lrn_backward lrn_bwd;
23+
24+
/**
 * @brief A subclass of MKLDNNLayer LRN(Local Response Norm) layer.
 *
 * Wraps the MKL-DNN across-channel LRN primitives for forward and
 * backward computation. Spatial size and channel count are preserved.
 *
 * The config file api is mkldnn_lrn
 */
class MKLDNNLRNLayer : public MKLDNNLayer {
protected:
  // save forward primitive_desc, which can be used in backward
  std::shared_ptr<lrn_fwd::primitive_desc> fwdPD_;
  // according to https://github.com/01org/mkl-dnn/blob/master/tests/gtests/
  // test_lrn_backward.cpp, lrn need workspace for backward
  // (created only during training; nullptr in inference)
  std::shared_ptr<mkldnn::memory> workspace_;

  // local window size of the normalization
  int localSize_;
  float alpha_, beta_;  // scale and pow in paddle

public:
  explicit MKLDNNLRNLayer(const LayerConfig& config) : MKLDNNLayer(config) {}

  ~MKLDNNLRNLayer() {}

  // Reads LRN hyper-parameters and geometry from the layer config.
  bool init(const LayerMap& layerMap,
            const ParameterMap& parameterMap) override;

  // Updates batch/spatial sizes; channel count must not change.
  void reshape(
      int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;

  // Builds the forward buffers, primitive descriptor and pipeline.
  void resetFwd(std::vector<mkldnn::primitive>& pipeline,
                std::vector<MKLDNNMatrixPtr>& inputs,
                MKLDNNMatrixPtr& out) override;

  // Builds the backward buffers, primitive descriptor and pipeline.
  void resetBwd(std::vector<mkldnn::primitive>& pipeline,
                std::vector<MKLDNNMatrixPtr>& inputs,
                MKLDNNMatrixPtr& out) override;

protected:
  // Creates input value and matching output value buffers.
  void resetFwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out);
  // Builds the forward primitive_desc and (in training) the workspace.
  void resetFwdPD(std::shared_ptr<lrn_fwd::primitive_desc>& pd,
                  MKLDNNMatrixPtr in,
                  MKLDNNMatrixPtr out);
  // Creates the forward primitive and appends it to the pipeline.
  void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
                        std::shared_ptr<lrn_fwd::primitive_desc>& pd,
                        MKLDNNMatrixPtr& in,
                        MKLDNNMatrixPtr& out);
  // Creates gradient buffers matching the forward value formats.
  void resetBwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out);
  // Builds the backward primitive_desc; leaves pd null if no in-grad.
  void resetBwdPD(std::shared_ptr<lrn_bwd::primitive_desc>& pd,
                  MKLDNNMatrixPtr& in,
                  MKLDNNMatrixPtr& out);
  // Creates the backward primitive and appends it to the pipeline.
  void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
                        std::shared_ptr<lrn_bwd::primitive_desc>& pd,
                        MKLDNNMatrixPtr& in,
                        MKLDNNMatrixPtr& out);
};
77+
78+
} // namespace paddle

0 commit comments

Comments
 (0)