Skip to content

Commit 19d8146

Browse files
authored
Merge pull request #358 from yu239/multi_binary_cross_entropy
multi_binary_cross_entropy when ids vector is provided
2 parents 58e1b3b + 5591292 commit 19d8146

File tree

11 files changed

+299
-18
lines changed

11 files changed

+299
-18
lines changed

paddle/cuda/include/hl_matrix.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,36 @@ extern void hl_matrix_cross_entropy_bp(real* grad_d,
126126
int dimM,
127127
int dimN);
128128

129+
/**
130+
* @brief Matrix multi-binary label cross entropy
131+
*
132+
* @param[in] output input matrix (M x N).
133+
* @param[out] entropy output matrix (M x 1).
134+
* @param[in] mat input sparse matrix.
135+
* @param[in] dimM matrix height.
136+
* @param[in] dimN matrix width.
137+
*/
138+
extern void hl_matrix_multi_binary_cross_entropy(real* output,
139+
real* entropy,
140+
hl_sparse_matrix_s mat,
141+
int dimM,
142+
int dimN);
143+
144+
/**
145+
* @brief Matrix multi-binary label cross entropy backprop
146+
*
147+
* @param[in] output input matrix (M x N).
148+
* @param[out] grad output matrix (M x N).
149+
* @param[in] mat input sparse matrix.
150+
* @param[in] dimM matrix height.
151+
* @param[in] dimN matrix width.
152+
*/
153+
extern void hl_matrix_multi_binary_cross_entropy_bp(real* output,
154+
real* grad,
155+
hl_sparse_matrix_s mat,
156+
int dimM,
157+
int dimN);
158+
129159
/**
130160
* @brief Matrix zero memory.
131161
*

paddle/cuda/include/stub/hl_matrix_stub.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,18 @@ inline void hl_matrix_cross_entropy_bp(real* grad_d,
5757
int dimM,
5858
int dimN) {}
5959

60+
// Stub of hl_matrix_multi_binary_cross_entropy: an empty body that ignores
// every argument and leaves `entropy` untouched.
inline void hl_matrix_multi_binary_cross_entropy(
    real* output, real* entropy, hl_sparse_matrix_s mat, int dimM, int dimN) {
}
65+
66+
// Stub of hl_matrix_multi_binary_cross_entropy_bp: an empty body that
// ignores every argument and leaves `grad` untouched.
inline void hl_matrix_multi_binary_cross_entropy_bp(
    real* output, real* grad, hl_sparse_matrix_s mat, int dimM, int dimN) {
}
71+
6072
// Stub of hl_matrix_zero_mem: empty body, both arguments are ignored.
inline void hl_matrix_zero_mem(real* data, int num) {}
6173

6274
inline void hl_param_relu_forward(real* output,

paddle/cuda/src/hl_cuda_matrix.cu

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License. */
1818
#include "hl_matrix_ops.cuh"
1919
#include "hl_matrix_apply.cuh"
2020
#include "hl_sequence.h"
21+
#include "hl_sparse.ph"
2122
#include "paddle/utils/Logging.h"
2223
#include "hl_device_functions.cuh"
2324
#include "hl_gpu_matrix_kernel.cuh"
@@ -317,6 +318,85 @@ void hl_matrix_classification_error(real* A_d,
317318
CHECK_SYNC("hl_matrix_classification_error");
318319
}
319320

321+
/**
 * Multi-binary-label cross entropy, one thread per matrix row.
 *
 * For row r the kernel subtracts from entropy[r]
 *   - log(1 - output[r][j])        for every column j in [0, dimN), and
 *   - log(o / (1 - o))             for every column listed for row r in the
 *                                  CSR arrays, with o = output[r][column].
 * Together this yields -log(o) for listed (positive) columns and
 * -log(1 - o) for all other columns.
 *
 * The result is accumulated into the value already held in entropy[r].
 *
 * Launch layout: 1-D grid and 1-D blocks covering at least dimM threads;
 * threads with index >= dimM do nothing. No shared memory is used.
 *
 * @param output  dense row-major matrix, dimM x dimN.
 * @param entropy per-row accumulator, dimM entries.
 * @param row     CSR row offsets; entries [0, dimM] are read.
 * @param col     CSR column indices.
 * @param dimM    number of rows.
 * @param dimN    number of columns.
 */
__global__ void KeMatrixMultiBinaryCrossEntropy(real* output,
                                                real* entropy,
                                                int* row,
                                                int* col,
                                                int dimM,
                                                int dimN) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < dimM) {
    // Widen before multiplying so the row offset cannot overflow int.
    const real* out_row = output + static_cast<size_t>(index) * dimN;

    // Accumulate in a register and store once instead of doing a global
    // read-modify-write on entropy[index] in every loop iteration.
    real sum = entropy[index];
    for (int i = 0; i < dimN; i++) {
      sum -= log(1 - out_row[i]);
    }

    const int begin = row[index];
    const int col_num = row[index + 1] - begin;
    const int* row_col = col + begin;
    for (int i = 0; i < col_num; i++) {
      real o = out_row[row_col[i]];
      sum -= log(o / (1 - o));
    }
    entropy[index] = sum;
  }
}
340+
341+
/**
 * Host-side launcher for KeMatrixMultiBinaryCrossEntropy.
 *
 * Validates the pointers, requires the sparse label matrix to be in CSR
 * format, then launches one thread per row (1024 threads per block) on
 * STREAM_DEFAULT and checks the launch with CHECK_SYNC.
 */
void hl_matrix_multi_binary_cross_entropy(real* output,
                                          real* entropy,
                                          hl_sparse_matrix_s csr_mat,
                                          int dimM,
                                          int dimN) {
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(entropy);
  CHECK_NOTNULL(csr_mat);
  CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);

  const int threadsPerBlock = 1024;
  // Ceiling division: enough blocks to give every row its own thread.
  const int numBlocks = (dimM + threadsPerBlock - 1) / threadsPerBlock;
  dim3 blockDims(threadsPerBlock);
  dim3 gridDims(numBlocks);

  hl_csr_matrix csr = (hl_csr_matrix)(csr_mat->matrix);
  KeMatrixMultiBinaryCrossEntropy<<< gridDims, blockDims, 0, STREAM_DEFAULT >>>
      (output, entropy, csr->csr_row, csr->csr_col, dimM, dimN);
  CHECK_SYNC("hl_matrix_multi_binary_cross_entropy failed");
}
359+
360+
/**
 * Backward pass of the multi-binary-label cross entropy, one thread per row.
 *
 * For row r the kernel
 *   - adds      1 / (1 - o)        to grad[r][j] for every column j, and
 *   - subtracts 1 / (o * (1 - o))  from grad[r][c] for every column c listed
 *                                  for row r in the CSR arrays,
 * where o is the corresponding element of `output`. The net contribution is
 * -1 / o for listed (positive) columns and 1 / (1 - o) for the others.
 *
 * The gradient is accumulated into the values already held in `grad`.
 *
 * Launch layout: 1-D grid and 1-D blocks covering at least dimM threads;
 * threads with index >= dimM do nothing. No shared memory is used.
 *
 * @param output  dense row-major matrix, dimM x dimN.
 * @param grad    dense row-major gradient accumulator, dimM x dimN.
 * @param row     CSR row offsets; entries [0, dimM] are read.
 * @param col     CSR column indices.
 * @param dimM    number of rows.
 * @param dimN    number of columns.
 */
__global__ void KeMatrixMultiBinaryCrossEntropyBp(real* output,
                                                  real* grad,
                                                  int* row,
                                                  int* col,
                                                  int dimM,
                                                  int dimN) {
  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (row_idx < dimM) {
    // A `real`-typed constant keeps the arithmetic in the matrix precision;
    // a bare 1.0 literal would promote every division to double when
    // `real` is float.
    const real one = 1;
    for (int i = 0; i < dimN; i++) {
      int index = row_idx * dimN + i;
      grad[index] += one / (one - output[index]);
    }
    int col_num = row[row_idx + 1] - row[row_idx];
    int* row_col = col + row[row_idx];
    for (int i = 0; i < col_num; i++) {
      int index = row_idx * dimN + row_col[i];
      grad[index] -= one / (output[index] * (one - output[index]));
    }
  }
}
380+
381+
/**
 * Host-side launcher for KeMatrixMultiBinaryCrossEntropyBp.
 *
 * Validates the pointers, requires the sparse label matrix to be in CSR
 * format, then launches one thread per row (1024 threads per block) on
 * STREAM_DEFAULT and checks the launch with CHECK_SYNC.
 */
void hl_matrix_multi_binary_cross_entropy_bp(real* output,
                                             real* grad,
                                             hl_sparse_matrix_s csr_mat,
                                             int dimM,
                                             int dimN) {
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(grad);
  CHECK_NOTNULL(csr_mat);
  CHECK_EQ(csr_mat->format, HL_SPARSE_CSR);

  const int threadsPerBlock = 1024;
  // Ceiling division: enough blocks to give every row its own thread.
  const int numBlocks = (dimM + threadsPerBlock - 1) / threadsPerBlock;
  dim3 blockDims(threadsPerBlock);
  dim3 gridDims(numBlocks);

  hl_csr_matrix csr = (hl_csr_matrix)(csr_mat->matrix);
  KeMatrixMultiBinaryCrossEntropyBp<<< gridDims, blockDims, 0, STREAM_DEFAULT >>>
      (output, grad, csr->csr_row, csr->csr_col, dimM, dimN);
  CHECK_SYNC("hl_matrix_multi_binary_cross_entropy_bp failed");
}
399+
320400
__global__ void KeMatrixCrossEntropy(real* O,
321401
real* E,
322402
int* label,
@@ -685,7 +765,7 @@ __global__ void KeMatrixAddSharedBias(real* A,
685765
int dim = N / channel;
686766
if (index < M * N) {
687767
int i = index % N;
688-
i = i / dim;
768+
i = i / dim;
689769
A[index] += scale * B[i];
690770
}
691771
}
@@ -713,7 +793,7 @@ __global__ void KeMatrixCollectSharedBias(real *B,
713793
const int dim,
714794
const int limit,
715795
real scale) {
716-
if (dim < limit) {
796+
if (dim < limit) {
717797
int index = blockIdx.x * blockDim.x + threadIdx.x;
718798
if (index < channel) {
719799
real sum = 0.0;

paddle/gserver/layers/CostLayer.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -462,25 +462,43 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap,
462462

463463
void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
464464
Matrix& target) {
465-
if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
466-
dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
467-
target.multiBinaryLabelCrossEntropy(output, *label.value);
465+
MatrixPtr value = nullptr;
466+
if (label.ids) {
467+
CHECK(!label.value);
468+
value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
469+
} else {
470+
CHECK(label.value);
471+
value = label.value;
472+
}
473+
474+
if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
475+
dynamic_cast<GpuSparseMatrix*>(value.get())) {
476+
target.multiBinaryLabelCrossEntropy(output, *value);
468477
} else {
469478
Matrix::resizeOrCreate(targetPerDim_, output.getHeight(), output.getWidth(),
470479
false, useGpu_);
471480

472-
targetPerDim_->binaryLabelCrossEntropy(output, *label.value);
481+
targetPerDim_->binaryLabelCrossEntropy(output, *value);
473482
targetPerDim_->rowSum(target);
474483
}
475484
}
476485

477486
void MultiBinaryLabelCrossEntropy::backwardImp(
478487
Matrix& output, Argument& label, Matrix& outputG) {
479-
if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
480-
dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
481-
outputG.multiBinaryLabelCrossEntropyBp(output, *label.value);
488+
MatrixPtr value = nullptr;
489+
if (label.ids) {
490+
CHECK(!value);
491+
value = label.ids->toOneHotSparseMatrix(output.getWidth(), useGpu_);
492+
} else {
493+
CHECK(label.value);
494+
value = label.value;
495+
}
496+
497+
if (dynamic_cast<CpuSparseMatrix*>(value.get()) ||
498+
dynamic_cast<GpuSparseMatrix*>(value.get())) {
499+
outputG.multiBinaryLabelCrossEntropyBp(output, *value);
482500
} else {
483-
outputG.binaryLabelCrossEntropyBp(output, *label.value);
501+
outputG.binaryLabelCrossEntropyBp(output, *value);
484502
}
485503
}
486504

paddle/gserver/tests/test_LayerGrad.cpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ TEST(Layer, multi_cross) {
528528
}
529529
}
530530

531-
TEST(Layer, multi_binary_label) {
531+
TEST(Layer, multi_binary_label_sparse_mat) {
532532
TestConfig config;
533533
config.layerConfig.set_type("multi_binary_label_cross_entropy");
534534
config.biasSize = 0;
@@ -538,9 +538,26 @@ TEST(Layer, multi_binary_label) {
538538
config.layerConfig.add_inputs();
539539
config.layerConfig.add_inputs();
540540

541-
// Not support GPU now
542-
testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
543-
/* trans */ false, /* useGpu */ false);
541+
for (auto useGpu : {false, true}) {
542+
testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
543+
/* trans */ false, useGpu);
544+
}
545+
}
546+
547+
// Gradient check for multi_binary_label_cross_entropy when the label input
// is declared as INPUT_LABEL (an id per sample) rather than a sparse matrix.
// Registered under the `Layer` suite like the other layer gradient tests so
// that suite-level filters pick it up.
TEST(Layer, multi_binary_label_id) {
  TestConfig config;
  config.layerConfig.set_type("multi_binary_label_cross_entropy");
  config.biasSize = 0;

  config.inputDefs.push_back({INPUT_DATA, "layer_0", 50, 0});
  config.inputDefs.push_back({INPUT_LABEL, "layer_1", 10, 0});
  config.layerConfig.add_inputs();
  config.layerConfig.add_inputs();

  for (auto useGpu : {false, true}) {
    testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
                  /* trans */ false, useGpu);
  }
}
545562

546563
TEST(Layer, multi_cross_with_selfnorm) {

paddle/math/CpuSparseMatrix.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,6 @@ void CpuSparseMatrix::setRow(size_t row, size_t colNum,
409409
if (format_ == SPARSE_CSR) {
410410
CHECK_LT(row, height_);
411411
CHECK(NULL != cols);
412-
for (size_t i = row; i < height_; i++) {
413-
CHECK_EQ(rows_[i + 1], rows_[i]);
414-
}
415412
if (0 == row) {
416413
rows_[row] = 0;
417414
}

paddle/math/Matrix.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,42 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
12681268
}
12691269
}
12701270

1271+
// Accumulates the multi-binary-label cross entropy of `output` against the
// sparse `label` into this matrix, which must be (height x 1). `output` must
// be a GpuMatrix and `label` a CSR GpuSparseMatrix with the same shape as
// `output`.
void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
  auto outputPtr = dynamic_cast<GpuMatrix*>(&output);
  GpuSparseMatrix* labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);

  CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
  CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
  CHECK(height_ == outputPtr->height_ && width_ == 1
        && outputPtr->width_ == labelPtr->getWidth()
        && outputPtr->height_ == labelPtr->getHeight())
      << "Matrix dimensions are not equal";

  // Hand the raw device buffers and the sparse handle to the CUDA launcher.
  hl_matrix_multi_binary_cross_entropy(outputPtr->data_,
                                       data_,
                                       labelPtr->sMatrix_.get(),
                                       height_,
                                       outputPtr->width_);
}
1288+
1289+
// Accumulates the gradient of the multi-binary-label cross entropy into this
// matrix, which must have the same shape as `output`. `output` must be a
// GpuMatrix and `label` a CSR GpuSparseMatrix with the same shape as
// `output`.
void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix &output, Matrix &label) {
  auto outputPtr = dynamic_cast<GpuMatrix*>(&output);
  GpuSparseMatrix* labelPtr = dynamic_cast<GpuSparseMatrix*>(&label);

  CHECK(outputPtr && labelPtr) << "Invalid argument pointer";
  CHECK(labelPtr->format_ == SPARSE_CSR) << "Matrix format not supported";
  CHECK(height_ == outputPtr->height_ && width_ == outputPtr->width_
        && outputPtr->width_ == labelPtr->getWidth()
        && outputPtr->height_ == labelPtr->getHeight())
      << "Matrix dimensions are not equal";

  // Hand the raw device buffers and the sparse handle to the CUDA launcher.
  hl_matrix_multi_binary_cross_entropy_bp(outputPtr->data_,
                                          data_,
                                          labelPtr->sMatrix_.get(),
                                          height_,
                                          width_);
}
1306+
12711307
/**
12721308
* CpuMatrix
12731309
*/

paddle/math/Matrix.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,6 +1303,10 @@ class GpuMatrix : public Matrix {
13031303
const size_t numChannels,
13041304
const real ratioH,
13051305
const real ratioW);
1306+
1307+
// Accumulates the multi-binary-label cross entropy of `output` against the
// CSR sparse `label` into this (height x 1) matrix.
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
1308+
1309+
// Accumulates the gradient of the multi-binary-label cross entropy into this
// matrix, which must have the same shape as `output`; `label` must be a CSR
// sparse matrix.
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
13061310
};
13071311

13081312
class CpuMatrix : public Matrix {

paddle/math/Vector.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include "paddle/utils/ThreadLocal.h"
2222
#include "paddle/utils/Thread.h"
2323
#include "paddle/utils/Flags.h"
24+
#include "Matrix.h"
2425
#include "hl_gpu.h"
2526
#include "hl_table_apply.h"
2627

@@ -73,6 +74,31 @@ std::shared_ptr<VectorT<T>> VectorT<T>::create(size_t size,
7374
}
7475
}
7576

77+
template <>
78+
MatrixPtr VectorT<real>::toOneHotSparseMatrix(size_t idRange, bool useGpu) {
79+
LOG(FATAL) << "Wrong for real vector";
80+
return nullptr;
81+
}
82+
83+
template <>
84+
MatrixPtr VectorT<int>::toOneHotSparseMatrix(size_t idRange, bool useGpu) {
85+
int height = getSize();
86+
int width = idRange;
87+
MatrixPtr mat = Matrix::createSparseMatrix(
88+
height, idRange, height, NO_VALUE, SPARSE_CSR, false, useGpu);
89+
90+
CpuIVector cpuIds(height);
91+
cpuIds.copyFrom(*this);
92+
int *idData = cpuIds.getData();
93+
94+
for (int i = 0; i < height; i ++) {
95+
const unsigned int id = idData[i];
96+
CHECK_LT(id, width);
97+
mat->setRow(i, 1, &id, nullptr);
98+
}
99+
return mat;
100+
}
101+
76102
template <class T>
77103
GpuVectorT<T>::GpuVectorT(size_t size)
78104
: VectorT<T>(size, std::make_shared<GpuMemoryHandle>(sizeof(T) * size),

paddle/math/Vector.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class BaseVector;
3737

3838
class SyncThreadPool;
3939

40+
class Matrix;
41+
4042
template<class T>
4143
class BaseVector : public BaseMatrixT<T> {
4244
public:
@@ -155,6 +157,12 @@ class VectorT : public BaseVector<T> {
155157
subVecFrom(src, interval.first, interval.second - interval.first);
156158
}
157159

160+
/**
161+
* convert the vector to a sparse one_hot matrix of width idRange
162+
* only applies to IVector
163+
*/
164+
// The VectorT<int> specialization creates a CSR matrix with one entry per
// row; the VectorT<real> specialization logs a fatal error.
std::shared_ptr<Matrix> toOneHotSparseMatrix(size_t idRange, bool useGpu);
165+
158166
/**
159167
* This function will crash if the size of src and dest is different.
160168
*/

0 commit comments

Comments
 (0)