|
| 1 | +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "MKLDNNConcatLayer.h" |
| 16 | + |
| 17 | +using namespace mkldnn; // NOLINT |
| 18 | +typedef memory::format format; |
| 19 | + |
| 20 | +namespace paddle { |
| 21 | + |
| 22 | +REGISTER_LAYER(mkldnn_concat, MKLDNNConcatLayer); |
| 23 | + |
| 24 | +bool MKLDNNConcatLayer::init(const LayerMap& layerMap, |
| 25 | + const ParameterMap& parameterMap) { |
| 26 | + if (!MKLDNNLayer::init(layerMap, parameterMap)) { |
| 27 | + return false; |
| 28 | + } |
| 29 | + CHECK_GT(inputLayers_.size(), 1UL); |
| 30 | + CHECK(!biasParameter_); |
| 31 | + return true; |
| 32 | +} |
| 33 | + |
| 34 | +void MKLDNNConcatLayer::reshape( |
| 35 | + int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { |
| 36 | + reshapeInput(bs, ih, iw); |
| 37 | + ic = inputLayers_[0]->getSize() / ih / iw; |
| 38 | + CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize()); |
| 39 | + CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw); |
| 40 | + CHECK_GT(inputLayers_.size(), 1UL); |
| 41 | + channels_.resize(inputLayers_.size()); |
| 42 | + channels_[0] = ic; |
| 43 | + // need change the output channel, so use oc_ instead |
| 44 | + // TODO(TJ): change API, use &oc |
| 45 | + oc_ = ic; |
| 46 | + for (size_t i = 1; i < inputLayers_.size(); i++) { |
| 47 | + int batchsize, height, witdh; |
| 48 | + reshapeInput(batchsize, height, witdh, i); |
| 49 | + CHECK_EQ(bs, batchsize); |
| 50 | + CHECK_EQ(ih, height); |
| 51 | + CHECK_EQ(iw, witdh); |
| 52 | + |
| 53 | + channels_[i] = inputLayers_[i]->getSize() / height / witdh; |
| 54 | + CHECK_EQ((size_t)channels_[i] * height * witdh, inputLayers_[i]->getSize()); |
| 55 | + oc_ += channels_[i]; |
| 56 | + } |
| 57 | + oh = ih; |
| 58 | + ow = iw; |
| 59 | + reshapeOutput(oh, ow); |
| 60 | + resizeOutput(bs, oc_ * oh * ow); |
| 61 | +} |
| 62 | + |
| 63 | +void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline, |
| 64 | + MKLDNNMatrixPtr& in, |
| 65 | + MKLDNNMatrixPtr& wgt, |
| 66 | + MKLDNNMatrixPtr& bias, |
| 67 | + MKLDNNMatrixPtr& out) { |
| 68 | + resetFwdBuffers(inVals_, out); |
| 69 | + in = inVals_[0]; |
| 70 | + |
| 71 | + std::shared_ptr<concat::primitive_desc> fwdPD; |
| 72 | + resetFwdPD(fwdPD, inVals_, out); |
| 73 | + |
| 74 | + resetFwdPipeline(pipeline, fwdPD, inVals_, out); |
| 75 | +} |
| 76 | + |
| 77 | +void MKLDNNConcatLayer::resetBwd(std::vector<primitive>& pipeline, |
| 78 | + MKLDNNMatrixPtr& in, |
| 79 | + MKLDNNMatrixPtr& wgt, |
| 80 | + MKLDNNMatrixPtr& bias, |
| 81 | + MKLDNNMatrixPtr& out) { |
| 82 | + resetBwdBuffers(inGrads_, out); |
| 83 | + in = inGrads_[0]; |
| 84 | + |
| 85 | + resetBwdPipeline(pipeline, bwds_, inGrads_, out); |
| 86 | +} |
| 87 | + |
| 88 | +void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, |
| 89 | + MKLDNNMatrixPtr& out) { |
| 90 | + inputs.resize(inputLayers_.size()); |
| 91 | + bool has8c = false, has16c = false, hasnc = false; |
| 92 | + for (size_t i = 0; i < inputs.size(); i++) { |
| 93 | + // resetInValue will use ic_ so temporary change as current input's channel |
| 94 | + // TODO(TJ): change ic_ as vector then can remove channels_ |
| 95 | + ic_ = channels_[i]; |
| 96 | + resetInValue(inputs[i], nullptr, i); |
| 97 | + CHECK(inputs[i]); |
| 98 | + auto dm = inputs[i]->getDims(); |
| 99 | + // inputs format can be different, but ndims must equal |
| 100 | + CHECK(i == 0 || dm.size() == inputs[0]->getDims().size()); |
| 101 | + CHECK_EQ(bs_, dm[0]); |
| 102 | + CHECK_EQ(channels_[i], dm[1]); |
| 103 | + if (dm.size() > 2) { |
| 104 | + CHECK_EQ(ih_, dm[2]); |
| 105 | + CHECK_EQ(iw_, dm[3]); |
| 106 | + } |
| 107 | + if (inputs[i]->getFormat() == format::nc) { |
| 108 | + hasnc = true; |
| 109 | + } |
| 110 | + if (inputs[i]->getFormat() == format::nChw8c) { |
| 111 | + has8c = true; |
| 112 | + } |
| 113 | + if (inputs[i]->getFormat() == format::nChw16c) { |
| 114 | + has16c = true; |
| 115 | + } |
| 116 | + } |
| 117 | + // change back, ic_ always save the input 0 size |
| 118 | + ic_ = channels_[0]; |
| 119 | + |
| 120 | + format outFmt; |
| 121 | + if (has16c && oc_ % 16 == 0) { |
| 122 | + outFmt = format::nChw16c; |
| 123 | + } else if (has8c && oc_ % 8 == 0) { |
| 124 | + outFmt = format::nChw8c; |
| 125 | + } else if (hasnc) { |
| 126 | + CHECK(oh_ == 1 && ow_ == 1); |
| 127 | + outFmt = format::nc; |
| 128 | + } else { |
| 129 | + outFmt = format::nchw; |
| 130 | + } |
| 131 | + memory::dims outDims = |
| 132 | + hasnc ? memory::dims{bs_, oc_} : memory::dims{bs_, oc_, oh_, ow_}; |
| 133 | + auto outPD = MKLDNNMatrix::createPrimitiveDesc(outDims, outFmt, engine_); |
| 134 | + resetOutValue(out, outPD); |
| 135 | +} |
| 136 | + |
| 137 | +void MKLDNNConcatLayer::resetFwdPD(std::shared_ptr<concat::primitive_desc>& pd, |
| 138 | + std::vector<MKLDNNMatrixPtr>& inputs, |
| 139 | + MKLDNNMatrixPtr out) { |
| 140 | + std::vector<memory::primitive_desc> srcPDs; |
| 141 | + for (size_t i = 0; i < inputs.size(); i++) { |
| 142 | + srcPDs.push_back(inputs[i]->getPrimitiveDesc()); |
| 143 | + } |
| 144 | + CHECK(out); |
| 145 | + pd.reset(new concat::primitive_desc(out->getMemoryDesc(), axis_, srcPDs)); |
| 146 | + CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc()); |
| 147 | +} |
| 148 | + |
| 149 | +void MKLDNNConcatLayer::resetFwdPipeline( |
| 150 | + std::vector<primitive>& pipeline, |
| 151 | + std::shared_ptr<concat::primitive_desc>& pd, |
| 152 | + std::vector<MKLDNNMatrixPtr>& inputs, |
| 153 | + MKLDNNMatrixPtr& out) { |
| 154 | + std::vector<primitive::at> srcs; |
| 155 | + for (size_t i = 0; i < inputs.size(); i++) { |
| 156 | + srcs.push_back(*(inputs[i])); |
| 157 | + } |
| 158 | + fwd_.reset(new concat(*pd, srcs, *out)); |
| 159 | + pipeline.push_back(*fwd_); |
| 160 | +} |
| 161 | + |
| 162 | +void MKLDNNConcatLayer::resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, |
| 163 | + MKLDNNMatrixPtr& out) { |
| 164 | + CHECK(outVal_); |
| 165 | + resetOutGrad(out, outVal_->getPrimitiveDesc()); |
| 166 | + CHECK(out); |
| 167 | + |
| 168 | + inputs.resize(inputLayers_.size()); |
| 169 | + for (size_t i = 0; i < inputs.size(); i++) { |
| 170 | + CHECK(inVals_[i]); |
| 171 | + // resetInGrad will use inVal_ |
| 172 | + // TODO(TJ): change move inVals_ to MKLDNNLayer ans remove inVal_ |
| 173 | + inVal_ = inVals_[i]; |
| 174 | + resetInGrad(inputs[i], inVals_[i]->getPrimitiveDesc(), i); |
| 175 | + CHECK_PRIMITIVE_DESC_EQ(inputs[i], inVals_[i]->getPrimitiveDesc()); |
| 176 | + } |
| 177 | + // change back, inVal_ always save the input 0 |
| 178 | + inVal_ = inVals_[0]; |
| 179 | +} |
| 180 | + |
| 181 | +void MKLDNNConcatLayer::resetBwdPipeline( |
| 182 | + std::vector<mkldnn::primitive>& pipeline, |
| 183 | + std::vector<std::shared_ptr<mkldnn::primitive>>& prims, |
| 184 | + std::vector<MKLDNNMatrixPtr>& inputs, |
| 185 | + MKLDNNMatrixPtr& out) { |
| 186 | + // reset the backward primitives |
| 187 | + memory::dims offsets = {0, 0, 0, 0}; |
| 188 | + prims.resize(inputs.size()); |
| 189 | + CHECK_EQ(inputs.size(), channels_.size()); |
| 190 | + for (size_t i = 0; i < inputs.size(); i++) { |
| 191 | + auto viewPD = view::primitive_desc( |
| 192 | + out->getPrimitiveDesc(), inputs[i]->getDims(), offsets); |
| 193 | + auto bwdPD = reorder::primitive_desc(viewPD.dst_primitive_desc(), |
| 194 | + inputs[i]->getPrimitiveDesc()); |
| 195 | + prims[i].reset(new reorder(bwdPD, *out, *(inputs[i]))); |
| 196 | + offsets[axis_] += channels_[i]; |
| 197 | + // push to pipeline |
| 198 | + pipeline.push_back(*prims[i]); |
| 199 | + } |
| 200 | +} |
| 201 | + |
| 202 | +} // namespace paddle |
0 commit comments