@@ -19,6 +19,23 @@ limitations under the License. */
19
19
20
20
namespace paddle {
21
21
22
+ MatrixPtr CrossChannelNormLayer::createSampleMatrix (MatrixPtr data,
23
+ size_t iter,
24
+ size_t spatialDim) {
25
+ return Matrix::create (data->getData () + iter * channels_ * spatialDim,
26
+ channels_,
27
+ spatialDim,
28
+ false ,
29
+ useGpu_);
30
+ }
31
+
32
+ MatrixPtr CrossChannelNormLayer::createSpatialMatrix (MatrixPtr data,
33
+ size_t iter,
34
+ size_t spatialDim) {
35
+ return Matrix::create (
36
+ data->getData () + iter * spatialDim, 1 , spatialDim, false , useGpu_);
37
+ }
38
+
22
39
void CrossChannelNormLayer::forward (PassType passType) {
23
40
Layer::forward (passType);
24
41
MatrixPtr inV = getInputValue (0 );
@@ -40,25 +57,19 @@ void CrossChannelNormLayer::forward(PassType passType) {
40
57
normBuffer_->addScalar (*normBuffer_, 1e-6 );
41
58
inV->square2 (*dataBuffer_);
42
59
for (size_t i = 0 ; i < batchSize; i++) {
43
- MatrixPtr inTmp = Matrix::create (
44
- inV->getData () + i * dataDim, channels_, spatialDim, false , useGpu_);
45
- MatrixPtr dataTmp = Matrix::create (dataBuffer_->getData () + i * dataDim,
46
- channels_,
47
- spatialDim,
48
- false ,
49
- useGpu_);
50
- MatrixPtr outTmp = Matrix::create (
51
- outV->getData () + i * dataDim, channels_, spatialDim, false , useGpu_);
52
- MatrixPtr normTmp = Matrix::create (
53
- normBuffer_->getData () + i * spatialDim, 1 , spatialDim, false , useGpu_);
60
+ const MatrixPtr inVTmp = createSampleMatrix (inV, i, spatialDim);
61
+ const MatrixPtr dataTmp = createSampleMatrix (dataBuffer_, i, spatialDim);
62
+ MatrixPtr outVTmp = createSampleMatrix (outV, i, spatialDim);
63
+ MatrixPtr normTmp = createSpatialMatrix (normBuffer_, i, spatialDim);
64
+
54
65
// compute norm.
55
- spatialBuffer_->sumCols (*dataTmp, 1 , 1 );
66
+ spatialBuffer_->sumCols (*dataTmp, 1 , 0 );
56
67
spatialBuffer_->sqrt2 (*spatialBuffer_);
57
68
normTmp->copyFrom (*spatialBuffer_);
58
- outTmp ->copyFrom (*inTmp );
59
- outTmp ->divRowVector (*spatialBuffer_);
69
+ outVTmp ->copyFrom (*inVTmp );
70
+ outVTmp ->divRowVector (*spatialBuffer_);
60
71
// scale the layer.
61
- outTmp ->mulColVector (*scale_->getW ());
72
+ outVTmp ->mulColVector (*scale_->getW ());
62
73
}
63
74
}
64
75
@@ -78,40 +89,31 @@ void CrossChannelNormLayer::backward(const UpdateCallback& callback) {
78
89
Matrix::resizeOrCreate (sampleBuffer_, channels_, spatialDim, false , useGpu_);
79
90
scaleDiff_->zeroMem ();
80
91
for (size_t i = 0 ; i < batchSize; i++) {
81
- // propagate to param.
82
- MatrixPtr dataBufferTmp =
83
- Matrix::create (dataBuffer_->getData () + i * dataDim,
84
- channels_,
85
- spatialDim,
86
- false ,
87
- useGpu_);
88
- const MatrixPtr inValueTmp = Matrix::create (
89
- inV->getData () + i * dataDim, channels_, spatialDim, false , useGpu_);
90
- const MatrixPtr outGradTmp = Matrix::create (
91
- outG->getData () + i * dataDim, channels_, spatialDim, false , useGpu_);
92
- MatrixPtr inGradTmp = Matrix::create (
93
- inG->getData () + i * dataDim, channels_, spatialDim, false , useGpu_);
94
- const MatrixPtr normTmp = Matrix::create (
95
- normBuffer_->getData () + i * spatialDim, 1 , spatialDim, false , useGpu_);
96
- channelBuffer_->sumRows (*dataBufferTmp, 1 , 1 );
92
+ MatrixPtr outGTmp = createSampleMatrix (outG, i, spatialDim);
93
+ const MatrixPtr dataTmp = createSampleMatrix (dataBuffer_, i, spatialDim);
94
+ const MatrixPtr inVTmp = createSampleMatrix (inV, i, spatialDim);
95
+ const MatrixPtr inGTmp = createSampleMatrix (inG, i, spatialDim);
96
+ const MatrixPtr normTmp = createSpatialMatrix (normBuffer_, i, spatialDim);
97
+
98
+ channelBuffer_->sumRows (*dataTmp, 1 , 0 );
97
99
channelBuffer_->dotDiv (*channelBuffer_, *(scale_->getW ()));
98
100
// store a / scale[i] in scaleDiff_ temporary
99
101
scaleDiff_->add (*channelBuffer_, 1 .);
100
102
101
- sampleBuffer_->dotMul (*inValueTmp , *outGradTmp );
103
+ sampleBuffer_->dotMul (*inVTmp , *outGTmp );
102
104
spatialBuffer_->sumCols (*sampleBuffer_, 1 ., 1 .);
103
105
// scale the grad
104
- inGradTmp ->copyFrom (*inValueTmp );
105
- inGradTmp ->mulRowVector (*spatialBuffer_);
106
+ inGTmp ->copyFrom (*inVTmp );
107
+ inGTmp ->mulRowVector (*spatialBuffer_);
106
108
// divide by square of norm
107
109
spatialBuffer_->dotMul (*normTmp, *normTmp);
108
- inGradTmp ->divRowVector (*spatialBuffer_);
110
+ inGTmp ->divRowVector (*spatialBuffer_);
109
111
// subtract
110
- inGradTmp ->add (*outGradTmp , -1 , 1 );
112
+ inGTmp ->add (*outGTmp , -1 , 1 );
111
113
// divide by norm
112
- inGradTmp ->divRowVector (*normTmp);
114
+ inGTmp ->divRowVector (*normTmp);
113
115
// scale the diff
114
- inGradTmp ->mulColVector (*scale_->getW ());
116
+ inGTmp ->mulColVector (*scale_->getW ());
115
117
}
116
118
// update scale
117
119
if (scale_->getWGrad ()) scale_->getWGrad ()->copyFrom (*scaleDiff_);
0 commit comments