Skip to content

Commit 069d000

Browse files
author
Haonan
committed
multi_binary_cross_entropy when ids vector is provided
1 parent ef5e483 commit 069d000

File tree

10 files changed

+263
-4
lines changed

10 files changed

+263
-4
lines changed

paddle/cuda/include/hl_matrix.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,36 @@ extern void hl_matrix_cross_entropy_bp(real* grad_d,
126126
int dimM,
127127
int dimN);
128128

129+
/**
130+
* @brief Matrix multi-binary label cross entropy
131+
*
132+
* @param[in] output input matrix (M x N).
133+
* @param[in,out] entropy per-row cost (M x 1); the result is accumulated into its existing contents.
134+
* @param[in] mat input sparse label matrix in CSR format (M x N).
135+
* @param[in] dimM matrix height.
136+
* @param[in] dimN matrix width.
137+
*/
138+
extern void hl_matrix_multi_binary_cross_entropy(real* output,
139+
real* entropy,
140+
hl_sparse_matrix_s mat,
141+
int dimM,
142+
int dimN);
143+
144+
/**
145+
* @brief Matrix multi-binary label cross entropy backprop
146+
*
147+
* @param[in] output input matrix (M x N).
148+
* @param[in,out] grad gradient matrix (M x N); the result is accumulated into its existing contents.
149+
* @param[in] mat input sparse label matrix in CSR format (M x N).
150+
* @param[in] dimM matrix height.
151+
* @param[in] dimN matrix width.
152+
*/
153+
extern void hl_matrix_multi_binary_cross_entropy_bp(real* output,
154+
real* grad,
155+
hl_sparse_matrix_s mat,
156+
int dimM,
157+
int dimN);
158+
129159
/**
130160
* @brief Matrix zero memory.
131161
*

paddle/cuda/include/stub/hl_matrix_stub.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,18 @@ inline void hl_matrix_cross_entropy_bp(real* grad_d,
5757
int dimM,
5858
int dimN) {}
5959

60+
/**
 * No-op stand-in for hl_matrix_multi_binary_cross_entropy.
 *
 * Takes the same arguments as the real implementation and does nothing;
 * `entropy` is left untouched.
 * NOTE(review): presumably selected for builds without CUDA -- confirm.
 */
inline void hl_matrix_multi_binary_cross_entropy(real* output,
                                                 real* entropy,
                                                 hl_sparse_matrix_s mat,
                                                 int dimM,
                                                 int dimN) {
}
65+
66+
/**
 * No-op stand-in for hl_matrix_multi_binary_cross_entropy_bp.
 *
 * Takes the same arguments as the real implementation and does nothing;
 * `grad` is left untouched.
 * NOTE(review): presumably selected for builds without CUDA -- confirm.
 */
inline void hl_matrix_multi_binary_cross_entropy_bp(real* output,
                                                    real* grad,
                                                    hl_sparse_matrix_s mat,
                                                    int dimM,
                                                    int dimN) {
}
71+
6072
/* No-op stub for hl_matrix_zero_mem: the body is empty, so `data` is not modified. */
inline void hl_matrix_zero_mem(real* data, int num) {}
6173

6274
inline void hl_param_relu_forward(real* output,

paddle/cuda/src/hl_cuda_matrix.cu

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License. */
1818
#include "hl_matrix_ops.cuh"
1919
#include "hl_matrix_apply.cuh"
2020
#include "hl_sequence.h"
21+
#include "hl_sparse.ph"
2122
#include "paddle/utils/Logging.h"
2223
#include "hl_device_functions.cuh"
2324
#include "hl_gpu_matrix_kernel.cuh"
@@ -317,6 +318,83 @@ void hl_matrix_classification_error(real* A_d,
317318
CHECK_SYNC("hl_matrix_classification_error");
318319
}
319320

321+
/**
 * Forward kernel of the multi-binary-label cross entropy.
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per matrix row; threads
 * whose global index is >= dimM do nothing. No shared memory is used.
 *
 * For row r the kernel subtracts from entropy[r]
 *   sum_j log(1 - output[r][j])                       over all dimN columns
 *   sum_{c in labels(r)} log(o / (1 - o)), o = output[r][c]
 * which equals adding -log(1 - o) for unlabeled columns and -log(o) for
 * labeled ones. The result is ACCUMULATED into the existing entropy[r], so
 * the caller decides the initial value (e.g. zero).
 *
 * @param output  dense row-major matrix, dimM x dimN.
 * @param entropy per-row accumulator, dimM entries (read and written).
 * @param row     CSR row offsets, dimM + 1 entries.
 * @param col     CSR column indices of the labeled entries.
 * @param dimM    number of rows.
 * @param dimN    number of columns.
 */
__global__ void KeMatrixMultiBinaryCrossEntropy(const real* output,
                                                real* entropy,
                                                const int* row,
                                                const int* col,
                                                int dimM,
                                                int dimN) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < dimM) {
    // Widen before multiplying so dimM * dimN may exceed INT_MAX.
    const real* out_row = output + static_cast<size_t>(index) * dimN;

    // Accumulate in a register and write back once instead of doing a
    // global-memory read-modify-write on every iteration. The operations
    // and their order are unchanged, so the result is the same.
    real acc = entropy[index];
    for (int i = 0; i < dimN; i++) {
      acc -= log(1 - out_row[i]);
    }

    // Labeled columns of this row: col[row[index] .. row[index + 1]).
    const int* row_col = col + row[index];
    int col_num = row[index + 1] - row[index];
    for (int i = 0; i < col_num; i++) {
      real o = out_row[row_col[i]];
      // Turns the -log(1 - o) already counted above into -log(o).
      acc -= log(o / (1 - o));
    }
    entropy[index] = acc;
  }
}
340+
341+
/**
 * Host launcher for KeMatrixMultiBinaryCrossEntropy.
 *
 * Checks that the three pointers are non-null, then launches the kernel on
 * STREAM_DEFAULT with 1024 threads per block and enough blocks to give every
 * one of the dimM rows its own thread. The CSR row/column arrays are taken
 * from csr_mat->matrix, which is treated as an hl_csr_matrix.
 */
void hl_matrix_multi_binary_cross_entropy(real* output,
                                          real* entropy,
                                          hl_sparse_matrix_s csr_mat,
                                          int dimM,
                                          int dimN) {
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(entropy);
  CHECK_NOTNULL(csr_mat);

  // Ceil-divide the rows over fixed-size blocks.
  const int threadsPerBlock = 1024;
  const int numBlocks = (dimM + threadsPerBlock - 1) / threadsPerBlock;
  dim3 blockShape(threadsPerBlock);
  dim3 gridShape(numBlocks);

  hl_csr_matrix csr = (hl_csr_matrix)(csr_mat->matrix);
  KeMatrixMultiBinaryCrossEntropy
      <<< gridShape, blockShape, 0, STREAM_DEFAULT >>>
      (output, entropy, csr->csr_row, csr->csr_col, dimM, dimN);
  CHECK_SYNC("hl_matrix_multi_binary_cross_entropy failed");
}
358+
359+
/**
 * Backward kernel of the multi-binary-label cross entropy.
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per matrix row; threads
 * whose global index is >= dimM do nothing. No shared memory is used.
 *
 * For row r the kernel ACCUMULATES into grad[r][*]:
 *   + 1 / (1 - o)            for every column            (o = output[r][j])
 *   - 1 / (o * (1 - o))      for every labeled column
 * so a labeled column ends up with a net contribution of -1 / o.
 *
 * @param output dense row-major matrix, dimM x dimN.
 * @param grad   gradient accumulator, dimM x dimN (read and written).
 * @param row    CSR row offsets, dimM + 1 entries.
 * @param col    CSR column indices of the labeled entries.
 * @param dimM   number of rows.
 * @param dimN   number of columns.
 */
__global__ void KeMatrixMultiBinaryCrossEntropyBp(const real* output,
                                                  real* grad,
                                                  const int* row,
                                                  const int* col,
                                                  int dimM,
                                                  int dimN) {
  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (row_idx < dimM) {
    // Keep the arithmetic in `real`: a bare 1.0 literal is a double and
    // would promote every division/addition to double precision when
    // `real` is float.
    const real one = 1;

    // Widen before multiplying so dimM * dimN may exceed INT_MAX.
    const size_t base = static_cast<size_t>(row_idx) * dimN;
    const real* out_row = output + base;
    real* grad_row = grad + base;

    for (int i = 0; i < dimN; i++) {
      grad_row[i] += one / (one - out_row[i]);
    }

    // Labeled columns of this row: col[row[row_idx] .. row[row_idx + 1]).
    int col_num = row[row_idx + 1] - row[row_idx];
    const int* row_col = col + row[row_idx];
    for (int i = 0; i < col_num; i++) {
      const int c = row_col[i];
      const real o = out_row[c];
      grad_row[c] -= one / (o * (one - o));
    }
  }
}
379+
380+
/**
 * Host launcher for KeMatrixMultiBinaryCrossEntropyBp.
 *
 * Checks that the three pointers are non-null, then launches the kernel on
 * STREAM_DEFAULT with 1024 threads per block and enough blocks to give every
 * one of the dimM rows its own thread. The CSR row/column arrays are taken
 * from csr_mat->matrix, which is treated as an hl_csr_matrix.
 */
void hl_matrix_multi_binary_cross_entropy_bp(real* output,
                                             real* grad,
                                             hl_sparse_matrix_s csr_mat,
                                             int dimM,
                                             int dimN) {
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(grad);
  CHECK_NOTNULL(csr_mat);

  // Ceil-divide the rows over fixed-size blocks.
  const int threadsPerBlock = 1024;
  const int numBlocks = (dimM + threadsPerBlock - 1) / threadsPerBlock;
  dim3 blockShape(threadsPerBlock);
  dim3 gridShape(numBlocks);

  hl_csr_matrix csr = (hl_csr_matrix)(csr_mat->matrix);
  KeMatrixMultiBinaryCrossEntropyBp
      <<< gridShape, blockShape, 0, STREAM_DEFAULT >>>
      (output, grad, csr->csr_row, csr->csr_col, dimM, dimN);
  CHECK_SYNC("hl_matrix_multi_binary_cross_entropy_bp failed");
}
397+
320398
__global__ void KeMatrixCrossEntropy(real* O,
321399
real* E,
322400
int* label,

paddle/gserver/layers/CostLayer.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,8 @@ bool MultiBinaryLabelCrossEntropy::init(const LayerMap& layerMap,
462462

463463
void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
464464
Matrix& target) {
465+
label.idsToSparseMatrix(output.getWidth(), useGpu_);
466+
465467
if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
466468
dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
467469
target.multiBinaryLabelCrossEntropy(output, *label.value);
@@ -476,6 +478,8 @@ void MultiBinaryLabelCrossEntropy::forwardImp(Matrix& output, Argument& label,
476478

477479
void MultiBinaryLabelCrossEntropy::backwardImp(
478480
Matrix& output, Argument& label, Matrix& outputG) {
481+
label.idsToSparseMatrix(output.getWidth(), useGpu_);
482+
479483
if (dynamic_cast<CpuSparseMatrix*>(label.value.get()) ||
480484
dynamic_cast<GpuSparseMatrix*>(label.value.get())) {
481485
outputG.multiBinaryLabelCrossEntropyBp(output, *label.value);

paddle/gserver/tests/test_LayerGrad.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,10 @@ TEST(Layer, multi_binary_label) {
538538
config.layerConfig.add_inputs();
539539
config.layerConfig.add_inputs();
540540

541-
// Not support GPU now
542-
testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
543-
/* trans */ false, /* useGpu */ false);
541+
for (auto useGpu : {false, true}) {
542+
testLayerGrad(config, "multi_binary_label_cross_entropy", 100,
543+
/* trans */ false, useGpu);
544+
}
544545
}
545546

546547
TEST(Layer, multi_cross_with_selfnorm) {

paddle/math/Matrix.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,42 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
12681268
}
12691269
}
12701270

1271+
/**
 * Accumulates the multi-binary-label cross entropy into this matrix.
 *
 * `output` must be a GpuMatrix and `label` a GpuSparseMatrix in CSR format.
 * This matrix must be (output height) x 1 and `label` must have the same
 * shape as `output`; any violation aborts through CHECK.
 */
void GpuMatrix::multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label) {
  auto output_ptr = dynamic_cast<GpuMatrix*>(&output);
  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);

  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
  CHECK(height_ == output_ptr->height_ && width_ == 1
        && output_ptr->width_ == label_ptr->getWidth()
        && output_ptr->height_ == label_ptr->getHeight())
      << "Matrix dimensions are not equal";

  // this->data_ is the per-row entropy buffer the kernel accumulates into.
  hl_matrix_multi_binary_cross_entropy(output_ptr->data_,
                                       data_,
                                       label_ptr->sMatrix_.get(),
                                       height_,
                                       output_ptr->width_);
}
1288+
1289+
/**
 * Accumulates the gradient of the multi-binary-label cross entropy into
 * this matrix.
 *
 * `output` must be a GpuMatrix and `label` a GpuSparseMatrix in CSR format.
 * This matrix, `output` and `label` must all have the same shape; any
 * violation aborts through CHECK.
 */
void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix &output, Matrix &label) {
  auto output_ptr = dynamic_cast<GpuMatrix*>(&output);
  auto label_ptr = dynamic_cast<GpuSparseMatrix*>(&label);

  CHECK(output_ptr && label_ptr) << "Invalid argument pointer";
  CHECK(label_ptr->format_ == SPARSE_CSR) << "Matrix format not supported";
  CHECK(height_ == output_ptr->height_ && width_ == output_ptr->width_
        && output_ptr->width_ == label_ptr->getWidth()
        && output_ptr->height_ == label_ptr->getHeight())
      << "Matrix dimensions are not equal";

  // this->data_ is the gradient buffer the kernel accumulates into.
  hl_matrix_multi_binary_cross_entropy_bp(output_ptr->data_,
                                          data_,
                                          label_ptr->sMatrix_.get(),
                                          height_,
                                          width_);
}
1306+
12711307
/**
12721308
* CpuMatrix
12731309
*/

paddle/math/Matrix.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,6 +1303,10 @@ class GpuMatrix : public Matrix {
13031303
const size_t numChannels,
13041304
const real ratioH,
13051305
const real ratioW);
1306+
1307+
/**
 * Accumulate the multi-binary-label cross entropy of `output` against the
 * sparse (CSR) `label` matrix into this matrix, which must be
 * (output height) x 1.
 */
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
1308+
1309+
/**
 * Accumulate the gradient of the multi-binary-label cross entropy with
 * respect to `output` into this matrix, which must have the same shape as
 * `output` and the sparse (CSR) `label` matrix.
 */
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
13061310
};
13071311

13081312
class CpuMatrix : public Matrix {

paddle/math/tests/test_matrixCompare.cpp

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2208,7 +2208,6 @@ void testCollectSharedBias(int numSamples, int dim, int channel) {
22082208
MatrixCheckErr(*cpuBias, *check);
22092209
}
22102210

2211-
22122211
TEST(Matrix, sharedBias) {
22132212
for (auto numSamples : {1, 100, 520}) {
22142213
for (auto dim : {100 * 16, 100 * 32}) {
@@ -2222,6 +2221,71 @@ TEST(Matrix, sharedBias) {
22222221
}
22232222
}
22242223

2224+
void testMultiBinaryLabelCrossEntropy(int numSamples, int dim) {
2225+
MatrixPtr output = std::make_shared<CpuMatrix>(numSamples, dim);
2226+
MatrixPtr cpuOutput = std::make_shared<CpuMatrix>(numSamples, dim);
2227+
MatrixPtr gpuOutput = std::make_shared<GpuMatrix>(numSamples, dim);
2228+
2229+
MatrixPtr cpuEntropy = std::make_shared<CpuMatrix>(numSamples, 1);
2230+
MatrixPtr gpuEntropy = std::make_shared<GpuMatrix>(numSamples, 1);
2231+
2232+
MatrixPtr cpuGrad = std::make_shared<CpuMatrix>(numSamples, dim);
2233+
MatrixPtr gpuGrad = std::make_shared<GpuMatrix>(numSamples, dim);
2234+
2235+
auto cpuRows = IVector::create(numSamples + 1, false);
2236+
auto cpuCols = IVector::create(numSamples, false);
2237+
auto gpuRows = IVector::create(numSamples + 1, true);
2238+
auto gpuCols = IVector::create(numSamples, true);
2239+
cpuRows->setElement(0, 0);
2240+
gpuRows->setElement(0, 0);
2241+
for (int i = 0; i < numSamples; i ++) {
2242+
int id = rand() % dim; // NOLINT
2243+
cpuRows->setElement(i + 1, i + 1);
2244+
gpuRows->setElement(i + 1, i + 1);
2245+
cpuCols->setElement(i, id);
2246+
gpuCols->setElement(i, id);
2247+
}
2248+
2249+
MatrixPtr cpuLabel = std::make_shared<CpuSparseMatrix>
2250+
(nullptr, cpuRows->getData(), cpuCols->getData(),
2251+
numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
2252+
MatrixPtr gpuLabel = std::make_shared<GpuSparseMatrix>
2253+
(nullptr, gpuRows->getData(), gpuCols->getData(),
2254+
numSamples, dim, numSamples, NO_VALUE, SPARSE_CSR, false);
2255+
2256+
output->randomizeUniform();
2257+
cpuOutput->zeroMem();
2258+
output->softmax(*cpuOutput);
2259+
gpuOutput->copyFrom(*cpuOutput);
2260+
2261+
cpuEntropy->zeroMem();
2262+
gpuEntropy->zeroMem();
2263+
cpuEntropy->multiBinaryLabelCrossEntropy(*cpuOutput, *cpuLabel);
2264+
gpuEntropy->multiBinaryLabelCrossEntropy(*gpuOutput, *gpuLabel);
2265+
2266+
MatrixPtr check1 = std::make_shared<CpuMatrix>(numSamples, 1);
2267+
check1->copyFrom(*gpuEntropy);
2268+
MatrixCheckErr(*cpuEntropy, *check1);
2269+
2270+
cpuGrad->zeroMem();
2271+
gpuGrad->zeroMem();
2272+
cpuGrad->multiBinaryLabelCrossEntropyBp(*cpuOutput, *cpuLabel);
2273+
gpuGrad->multiBinaryLabelCrossEntropyBp(*gpuOutput, *gpuLabel);
2274+
2275+
MatrixPtr check2 = std::make_shared<CpuMatrix>(numSamples, dim);
2276+
check2->copyFrom(*gpuGrad);
2277+
MatrixCheckErr(*cpuGrad, *check2);
2278+
}
2279+
2280+
// Runs the CPU-vs-GPU comparison over every (numSamples, dim) combination.
TEST(Matrix, multiBinaryCrossEntropy) {
  const int sampleCounts[] = {1, 100, 500};
  const int dims[] = {1000, 10000, 100000};
  for (int numSamples : sampleCounts) {
    for (int dim : dims) {
      VLOG(3) << " numSamples=" << numSamples << " dim=" << dim;
      testMultiBinaryLabelCrossEntropy(numSamples, dim);
    }
  }
}
2288+
22252289
int main(int argc, char** argv) {
22262290
testing::InitGoogleTest(&argc, argv);
22272291
initMain(argc, argv);

paddle/parameter/Argument.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,4 +572,26 @@ void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
572572
}
573573
}
574574

575+
void Argument::idsToSparseMatrix(int width, bool useGpu) {
576+
if (ids) {
577+
CHECK(!value);
578+
int height = ids->getSize();
579+
int nnz = height;
580+
auto rows = IVector::create(height + 1, useGpu);
581+
auto cols = IVector::create(nnz, useGpu);
582+
rows->setElement(0, 0);
583+
for (int i = 0; i < height; i ++) {
584+
int id = ids->getElement(i);
585+
CHECK_LT(id, width);
586+
rows->setElement(i + 1, i + 1);
587+
cols->setElement(i, id);
588+
}
589+
value = Matrix::createSparseMatrix(
590+
nullptr, rows->getData(), cols->getData(),
591+
height, width, nnz, NO_VALUE, SPARSE_CSR, false, useGpu);
592+
} else {
593+
CHECK(value);
594+
}
595+
}
596+
575597
} // namespace paddle

paddle/parameter/Argument.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,14 @@ struct Argument {
286286
sequence has sub-sequence degrades to a sequence.
287287
*/
288288
void degradeSequence(const Argument& input, bool useGpu);
289+
290+
/*
291+
@brief convert the ids vector to value as a sparse matrix
292+
the ids vector remains valid
293+
@param width the matrix width (id range)
294+
@param useGpu whether the index vectors and the sparse matrix are created on the GPU
295+
*/
296+
void idsToSparseMatrix(int width, bool useGpu);
289297
};
290298

291299
} // namespace paddle

0 commit comments

Comments
 (0)