Skip to content

Commit 3bce32b

Browse files
author
gaoyuan
committed
Add create matrix pointer function
1 parent 17c697c commit 3bce32b

File tree

2 files changed

+42
-39
lines changed

2 files changed

+42
-39
lines changed

paddle/gserver/layers/CrossChannelNormLayer.cpp

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@ limitations under the License. */
1919

2020
namespace paddle {
2121

22+
MatrixPtr CrossChannelNormLayer::createSampleMatrix(MatrixPtr data,
23+
size_t iter,
24+
size_t spatialDim) {
25+
return Matrix::create(data->getData() + iter * channels_ * spatialDim,
26+
channels_,
27+
spatialDim,
28+
false,
29+
useGpu_);
30+
}
31+
32+
MatrixPtr CrossChannelNormLayer::createSpatialMatrix(MatrixPtr data,
33+
size_t iter,
34+
size_t spatialDim) {
35+
return Matrix::create(
36+
data->getData() + iter * spatialDim, 1, spatialDim, false, useGpu_);
37+
}
38+
2239
void CrossChannelNormLayer::forward(PassType passType) {
2340
Layer::forward(passType);
2441
MatrixPtr inV = getInputValue(0);
@@ -40,25 +57,19 @@ void CrossChannelNormLayer::forward(PassType passType) {
4057
normBuffer_->addScalar(*normBuffer_, 1e-6);
4158
inV->square2(*dataBuffer_);
4259
for (size_t i = 0; i < batchSize; i++) {
43-
MatrixPtr inTmp = Matrix::create(
44-
inV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
45-
MatrixPtr dataTmp = Matrix::create(dataBuffer_->getData() + i * dataDim,
46-
channels_,
47-
spatialDim,
48-
false,
49-
useGpu_);
50-
MatrixPtr outTmp = Matrix::create(
51-
outV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
52-
MatrixPtr normTmp = Matrix::create(
53-
normBuffer_->getData() + i * spatialDim, 1, spatialDim, false, useGpu_);
60+
const MatrixPtr inVTmp = createSampleMatrix(inV, i, spatialDim);
61+
const MatrixPtr dataTmp = createSampleMatrix(dataBuffer_, i, spatialDim);
62+
MatrixPtr outVTmp = createSampleMatrix(outV, i, spatialDim);
63+
MatrixPtr normTmp = createSpatialMatrix(normBuffer_, i, spatialDim);
64+
5465
// compute norm.
55-
spatialBuffer_->sumCols(*dataTmp, 1, 1);
66+
spatialBuffer_->sumCols(*dataTmp, 1, 0);
5667
spatialBuffer_->sqrt2(*spatialBuffer_);
5768
normTmp->copyFrom(*spatialBuffer_);
58-
outTmp->copyFrom(*inTmp);
59-
outTmp->divRowVector(*spatialBuffer_);
69+
outVTmp->copyFrom(*inVTmp);
70+
outVTmp->divRowVector(*spatialBuffer_);
6071
// scale the layer.
61-
outTmp->mulColVector(*scale_->getW());
72+
outVTmp->mulColVector(*scale_->getW());
6273
}
6374
}
6475

@@ -78,40 +89,31 @@ void CrossChannelNormLayer::backward(const UpdateCallback& callback) {
7889
Matrix::resizeOrCreate(sampleBuffer_, channels_, spatialDim, false, useGpu_);
7990
scaleDiff_->zeroMem();
8091
for (size_t i = 0; i < batchSize; i++) {
81-
// propagate to param.
82-
MatrixPtr dataBufferTmp =
83-
Matrix::create(dataBuffer_->getData() + i * dataDim,
84-
channels_,
85-
spatialDim,
86-
false,
87-
useGpu_);
88-
const MatrixPtr inValueTmp = Matrix::create(
89-
inV->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
90-
const MatrixPtr outGradTmp = Matrix::create(
91-
outG->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
92-
MatrixPtr inGradTmp = Matrix::create(
93-
inG->getData() + i * dataDim, channels_, spatialDim, false, useGpu_);
94-
const MatrixPtr normTmp = Matrix::create(
95-
normBuffer_->getData() + i * spatialDim, 1, spatialDim, false, useGpu_);
96-
channelBuffer_->sumRows(*dataBufferTmp, 1, 1);
92+
MatrixPtr outGTmp = createSampleMatrix(outG, i, spatialDim);
93+
const MatrixPtr dataTmp = createSampleMatrix(dataBuffer_, i, spatialDim);
94+
const MatrixPtr inVTmp = createSampleMatrix(inV, i, spatialDim);
95+
const MatrixPtr inGTmp = createSampleMatrix(inG, i, spatialDim);
96+
const MatrixPtr normTmp = createSpatialMatrix(normBuffer_, i, spatialDim);
97+
98+
channelBuffer_->sumRows(*dataTmp, 1, 0);
9799
channelBuffer_->dotDiv(*channelBuffer_, *(scale_->getW()));
98100
// store a / scale[i] in scaleDiff_ temporary
99101
scaleDiff_->add(*channelBuffer_, 1.);
100102

101-
sampleBuffer_->dotMul(*inValueTmp, *outGradTmp);
103+
sampleBuffer_->dotMul(*inVTmp, *outGTmp);
102104
spatialBuffer_->sumCols(*sampleBuffer_, 1., 1.);
103105
// scale the grad
104-
inGradTmp->copyFrom(*inValueTmp);
105-
inGradTmp->mulRowVector(*spatialBuffer_);
106+
inGTmp->copyFrom(*inVTmp);
107+
inGTmp->mulRowVector(*spatialBuffer_);
106108
// divide by square of norm
107109
spatialBuffer_->dotMul(*normTmp, *normTmp);
108-
inGradTmp->divRowVector(*spatialBuffer_);
110+
inGTmp->divRowVector(*spatialBuffer_);
109111
// subtract
110-
inGradTmp->add(*outGradTmp, -1, 1);
112+
inGTmp->add(*outGTmp, -1, 1);
111113
// divide by norm
112-
inGradTmp->divRowVector(*normTmp);
114+
inGTmp->divRowVector(*normTmp);
113115
// scale the diff
114-
inGradTmp->mulColVector(*scale_->getW());
116+
inGTmp->mulColVector(*scale_->getW());
115117
}
116118
// update scale
117119
if (scale_->getWGrad()) scale_->getWGrad()->copyFrom(*scaleDiff_);

paddle/gserver/layers/NormLayer.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,10 @@ class CrossChannelNormLayer : public NormLayer {
8080
explicit CrossChannelNormLayer(const LayerConfig& config)
8181
: NormLayer(config) {}
8282
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
83-
8483
void forward(PassType passType);
8584
void backward(const UpdateCallback& callback);
85+
MatrixPtr createSampleMatrix(MatrixPtr data, size_t iter, size_t spatialDim);
86+
MatrixPtr createSpatialMatrix(MatrixPtr data, size_t iter, size_t spatialDim);
8687

8788
protected:
8889
size_t channels_;

0 commit comments

Comments
 (0)