
Commit b9dfe8e

Author: Haonan

Merge pull request #1231 from yu239/rotate_and_flip

One bug fix and two new features

2 parents 9763761 + 73dcf2c · commit b9dfe8e

22 files changed: +435 −45 lines

paddle/cuda/include/hl_matrix.h

Lines changed: 12 additions & 0 deletions
@@ -267,4 +267,16 @@ extern void hl_matrix_collect_shared_bias(real* B_d,
                                           const int dimN,
                                           real scale);
 
+/**
+ * @brief Matrix rotation in 90 degrees
+ *
+ * @param[in] mat        input matrix (M x N).
+ * @param[out] matRot    output matrix (N x M).
+ * @param[in] dimM       input matrix height.
+ * @param[in] dimN       input matrix width.
+ * @param[in] clockWise  rotation direction
+ */
+extern void hl_matrix_rotate(
+    real* mat, real* matRot, int dimM, int dimN, bool clockWise);
+
 #endif /* HL_MATRIX_H_ */
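The kernel implementing this declaration appears in hl_cuda_matrix.cu below. As a plain summary of the index convention, here is a hypothetical CPU reference (rotate90_cpu is not part of this commit; it merely restates the mapping the CUDA kernel computes, assuming row-major storage):

// Hypothetical CPU reference for hl_matrix_rotate's index mapping
// (illustration only, not part of this commit).
// mat is dimM x dimN, matRot is dimN x dimM, both row-major.
void rotate90_cpu(const real* mat, real* matRot,
                  int dimM, int dimN, bool clockWise) {
  for (int i = 0; i < dimM; ++i) {
    for (int j = 0; j < dimN; ++j) {
      if (clockWise) {
        // output(j, i) <- input(dimM - i - 1, j)
        matRot[j * dimM + i] = mat[(dimM - i - 1) * dimN + j];
      } else {
        // output(j, i) <- input(i, dimN - j - 1)
        matRot[j * dimM + i] = mat[i * dimN + (dimN - j - 1)];
      }
    }
  }
}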

paddle/cuda/include/stub/hl_matrix_stub.h

Lines changed: 4 additions & 0 deletions
@@ -106,4 +106,8 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
                                           const int dimM,
                                           const int dimN,
                                           real scale) {}
+
+inline void hl_matrix_rotate(
+    real* mat, real* matRot, int dimM, int dimN, bool clockWise) {}
+
 #endif  // HL_MATRIX_STUB_H_

paddle/cuda/src/hl_cuda_matrix.cu

Lines changed: 25 additions & 0 deletions
@@ -840,3 +840,28 @@ void hl_matrix_collect_shared_bias(real* B_d,
       (B_d, A_d, channel, dimM, dimN, dim, limit, scale);
   CHECK_SYNC("hl_matrix_collect_shared_bias failed");
 }
+
+__global__ void keMatrixRotate(real* mat, real* matRot,
+                               int dimM, int dimN, bool clockWise) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < dimM * dimN) {
+    int i = idx / dimN;
+    int j = idx % dimN;
+    if (clockWise) {
+      matRot[j * dimM + i] = mat[(dimM - i - 1) * dimN + j];
+    } else {
+      matRot[j * dimM + i] = mat[i * dimN + (dimN - j - 1)];
+    }
+  }
+}
+
+void hl_matrix_rotate(real *mat, real* matRot,
+                      int dimM, int dimN, bool clockWise) {
+  CHECK_NOTNULL(mat);
+  CHECK_NOTNULL(matRot);
+  const int threads = 512;
+  const int blocks = DIVUP(dimM * dimN, threads);
+  keMatrixRotate<<< blocks, threads, 0, STREAM_DEFAULT >>>
+      (mat, matRot, dimM, dimN, clockWise);
+  CHECK_SYNC("hl_matrix_rotate failed");
+}
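Each thread handles one element, so the grid must cover all dimM * dimN positions. DIVUP is presumably Paddle's usual ceiling-division macro; under that assumption the launch geometry works out as sketched below, with the idx < dimM * dimN guard in the kernel masking off surplus threads in the last block.

// Assumed definition (from Paddle's CUDA utility headers):
// #define DIVUP(x, y) (((x) + (y) - 1) / (y))
//
// Example: dimM * dimN = 1000, threads = 512
//   blocks = DIVUP(1000, 512) = 2  -> 1024 threads launched,
//   of which the last 24 fail the idx < dimM * dimN check and do nothing.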

paddle/gserver/layers/FeatureMapExpandLayer.cpp

Lines changed: 3 additions & 0 deletions
@@ -95,6 +95,9 @@ void FeatureMapExpandLayer::forward(PassType passType) {
 
 void FeatureMapExpandLayer::backward(const UpdateCallback& callback) {
   MatrixPtr inGrad = getInputGrad(0);
+  if (NULL == inGrad) {
+    return;
+  }
   MatrixPtr outGrad = getOutputGrad();
   size_t batchSize = getInput(0).getBatchSize();
   int imgSize = inGrad->getWidth();
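This guard is presumably the bug fix named in the commit title: when the input layer does not need a gradient, getInputGrad(0) returns a null MatrixPtr, and the inGrad->getWidth() call below it would dereference null; returning early skips the backward pass for that input.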
paddle/gserver/layers/RotateLayer.cpp

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "RotateLayer.h"
+
+namespace paddle {
+
+REGISTER_LAYER(rotate, RotateLayer);
+
+bool RotateLayer::init(const LayerMap& layerMap,
+                       const ParameterMap& parameterMap) {
+  Layer::init(layerMap, parameterMap);
+
+  CHECK_EQ(inputLayers_.size(), 1UL);
+  height_ = config_.height();
+  width_ = config_.width();
+  CHECK_GT(height_, 0);
+  CHECK_GT(width_, 0);
+  return true;
+}
+
+void RotateLayer::forward(PassType passType) {
+  Layer::forward(passType);
+
+  MatrixPtr input = getInputValue(0);
+  batchSize_ = input->getHeight();
+  size_ = input->getWidth();
+  CHECK_GE(size_, height_ * width_);
+  CHECK_EQ(size_ % (height_ * width_), 0)
+      << "total size_ is not dividable by (height_ * width_), i.e., "
+      << "channel number should be an integer";
+  channels_ = size_ / (height_ * width_);
+
+  resizeOutput(batchSize_, size_);
+
+  MatrixPtr outV = getOutputValue();
+  for (int b = 0; b < batchSize_; b++) {   // for each input feat map
+    for (int c = 0; c < channels_; c++) {  // for each feat channel
+      MatrixPtr inputSample =
+          Matrix::create(input->getData() + b * size_ + c * height_ * width_,
+                         height_,
+                         width_,
+                         false,
+                         useGpu_);
+      MatrixPtr outputSample =
+          Matrix::create(outV->getData() + b * size_ + c * height_ * width_,
+                         width_,
+                         height_,
+                         false,
+                         useGpu_);
+      inputSample->rotate(outputSample, false, true /* clock-wise */);
+    }
+  }
+
+  if (getInputGrad(0)) {
+    zeroGrad();
+  }
+}
+
+void RotateLayer::backward(const UpdateCallback& callback) {
+  (void)callback;
+
+  MatrixPtr outputGrad = getOutputGrad();
+  if (outputGrad == NULL) {
+    return;
+  }
+  // the grad should be rotated in the reverse direction
+  MatrixPtr preGrad = getInputGrad(0);
+
+  for (int b = 0; b < batchSize_; b++) {   // for each input feat map
+    for (int c = 0; c < channels_; c++) {  // for each feat channel
+      MatrixPtr inputSampleGrad =
+          Matrix::create(preGrad->getData() + b * size_ + c * height_ * width_,
+                         height_,
+                         width_,
+                         false,
+                         useGpu_);
+      MatrixPtr outputSampleGrad = Matrix::create(
+          outputGrad->getData() + b * size_ + c * height_ * width_,
+          width_,
+          height_,
+          false,
+          useGpu_);
+      MatrixPtr tmpGrad = nullptr;
+      outputSampleGrad->rotate(tmpGrad, true, false /* anti clock-wise */);
+      inputSampleGrad->add(*tmpGrad);
+    }
+  }
+}
+
+}  // namespace paddle
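Since forward applies a clockwise 90-degree rotation to each channel, backward has to route every output gradient back through the inverse map, i.e. the anti-clockwise rotation performed via tmpGrad above. In the notation of RotateLayer.h:

\f[
y(j, i, :) = x(M - i - 1, j, :)
\quad\Longrightarrow\quad
\frac{\partial L}{\partial x}(M - i - 1, j, :) = \frac{\partial L}{\partial y}(j, i, :)
\f]

The code accumulates this with add() rather than assigning, since preGrad may already hold contributions from other paths.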
paddle/gserver/layers/RotateLayer.h

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Layer.h"
+#include "paddle/math/Matrix.h"
+
+namespace paddle {
+/**
+ * A layer for rotating a multi-channel feature map (M x N x C) in the spatial
+ * domain
+ * The rotation is 90 degrees in clock-wise for each channel
+ * \f[
+ *   y(j,i,:) = x(M-i-1,j,:)
+ * \f]
+ * where \f$x\f$ is (M x N x C) input, and \f$y\f$ is (N x M x C) output.
+ *
+ * The config file api is rotate_layer
+ *
+ */
+
+class RotateLayer : public Layer {
+public:
+  explicit RotateLayer(const LayerConfig& config) : Layer(config) {}
+
+  bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
+
+  void forward(PassType passType);
+  void backward(const UpdateCallback& callback = nullptr);
+
+private:
+  int batchSize_;
+  int size_;
+  int height_;
+  int width_;
+  int channels_;
+};
+
+}  // namespace paddle
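For concreteness, a worked instance of the formula above with M = 2, N = 3 and a single channel:

\f[
x = \begin{pmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{pmatrix}
\quad\Longrightarrow\quad
y = \begin{pmatrix} 4 & 1 \\ 5 & 2 \\ 6 & 3 \end{pmatrix}
\f]

e.g. y(0, 0) = x(M - 1 - 0, 0) = x(1, 0) = 4: the first column of x, read bottom-up, becomes the first row of y.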

paddle/gserver/layers/TransLayer.h

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ limitations under the License. */
 
 namespace paddle {
 /**
- * A layer for transposition.
+ * A layer for transposing a minibatch matrix.
  * \f[
  * y = x^\mathrm{T}
  * \f]

paddle/gserver/tests/test_LayerGrad.cpp

Lines changed: 19 additions & 0 deletions
@@ -1316,6 +1316,25 @@ TEST(Layer, ResizeLayer) {
   }
 }
 
+TEST(Layer, RotateLayer) {
+  TestConfig config;
+  config.biasSize = 0;
+  config.layerConfig.set_type("rotate");
+  const int CHANNEL = 2;
+  const int HEIGHT = 8;
+  const int WIDTH = 4;
+  const int INPUT_SIZE = HEIGHT * WIDTH * CHANNEL;
+  config.layerConfig.set_size(INPUT_SIZE);
+  config.layerConfig.set_height(HEIGHT);
+  config.layerConfig.set_width(WIDTH);
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", INPUT_SIZE, 0});
+  config.layerConfig.add_inputs();
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "rotate", 100, false, useGpu);
+  }
+}
+
 TEST(Layer, NCELayer) {
   TestConfig config;
   size_t numClasses = 4;
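testLayerGrad here presumably runs the framework's numeric-vs-analytic gradient check on the new layer; reading the call against the file's other tests, 100 is the batch size and false means the input is not transposed, with the loop exercising both the CPU and GPU code paths.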

paddle/math/CpuSparseMatrix.cpp

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ MatrixPtr CpuSparseMatrix::subMatrix(size_t startRow, size_t numRows) {
 }
 
 /* mem MUST be alloced outside (memAlloc=false) */
-void CpuSparseMatrix::transpose(MatrixPtr matTrans, bool memAlloc) {
+void CpuSparseMatrix::transpose(MatrixPtr& matTrans, bool memAlloc) {
   CHECK(!memAlloc);
   CpuSparseMatrix* mat = dynamic_cast<CpuSparseMatrix*>(matTrans.get());
   if (format_ == SPARSE_CSR) {

paddle/math/CpuSparseMatrix.h

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ class CpuSparseMatrix : public Matrix {
   void zeroMem();
 
   /// mem MUST be alloced outside (memAlloc=false)
-  void transpose(MatrixPtr matTrans, bool memAlloc);
+  void transpose(MatrixPtr& matTrans, bool memAlloc);
 
   void mul(const Matrix& A, const Matrix& B, real alpha, real beta);
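Passing matTrans as MatrixPtr& lets the callee allocate the result and hand it back to the caller when memAlloc is true; this is presumably what RotateLayer::backward relies on when it calls rotate(tmpGrad, true, ...) with tmpGrad == nullptr. CpuSparseMatrix itself still requires the buffer to be pre-allocated, as the CHECK(!memAlloc) above enforces.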
