
Commit 6fed6f2 (parent: 5392a50)

Add support of sparse_binary_vector as input for fm layer

File tree

3 files changed: +34, −13 lines

paddle/gserver/layers/FactorizationMachineLayer.cpp

Lines changed: 14 additions & 6 deletions
@@ -96,15 +96,20 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
 
   /* Calculate the gradients of the latentVectors_ matrix */
   if (latentVectors_->getWGrad()) {
-    MatrixPtr tmpInput = inputV->clone(0, 0, useGpu_);
     if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
+      Matrix::resizeOrCreateSparseMatrix(tmpInput_,
+                                         inputV->getHeight(),
+                                         inputV->getWidth(),
+                                         inputV->getElementCnt());
+
       CpuSparseMatrix* sparseInputV =
           dynamic_cast<CpuSparseMatrix*>(inputV.get());
       CpuSparseMatrix* sparseInputSquare =
           dynamic_cast<CpuSparseMatrix*>(inputSquare_.get());
       CpuSparseMatrix* sparseTmpInput =
-          dynamic_cast<CpuSparseMatrix*>(tmpInput.get());
+          dynamic_cast<CpuSparseMatrix*>(tmpInput_.get());
       sparseTmpInput->copyFrom(*sparseInputV);
+
       sparseTmpInput->rowScale(0, *sparseInputV, *oGrad);
       latentVectors_->getWGrad()->mul(
           *sparseTmpInput->getTranspose(), *inputMulFactor_, 1, 1);
@@ -115,12 +120,15 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
       negOnes_->add(-1);
       tmpSum_->mul(*negOnes_, *sparseTmpInput, 1, 0);
     } else {
-      tmpInput->rowScale(0, *inputV, *oGrad);
+      Matrix::resizeOrCreate(
+          tmpInput_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
+
+      tmpInput_->rowScale(0, *inputV, *oGrad);
       latentVectors_->getWGrad()->mul(
-          *tmpInput->getTranspose(), *inputMulFactor_, 1, 1);
-      tmpInput->rowScale(0, *inputSquare_, *oGrad);
+          *tmpInput_->getTranspose(), *inputMulFactor_, 1, 1);
+      tmpInput_->rowScale(0, *inputSquare_, *oGrad);
 
-      tmpSum_->sumCols(*tmpInput, -1, 0);
+      tmpSum_->sumCols(*tmpInput_, -1, 0);
     }
 
     latentVectors_->getWGrad()->addRowScale(
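
Not part of the commit, but a reading aid for the backward pass above, using the standard factorization machine formulation (the symbols below are mine, not the code's): the second-order output for a sample x with latent matrix V is

  y(x) = \frac{1}{2} \sum_{f=1}^{k} \Big[ \big( \sum_i v_{i,f} x_i \big)^2 - \sum_i v_{i,f}^2 x_i^2 \Big]

so the gradient with respect to one latent entry is

  \frac{\partial y}{\partial v_{i,f}} = x_i \sum_j v_{j,f} x_j - v_{i,f} x_i^2

Both branches evaluate this expression: tmpInput_ holds the input rows scaled by the output gradient oGrad, while inputMulFactor_ and inputSquare_ presumably hold the per-sample sums \sum_j v_{j,f} x_j and the squared inputs x_i^2, respectively. The only difference between the branches is whether tmpInput_ is (re)allocated as a CpuSparseMatrix or as a dense matrix, which is what lets a sparse_binary_vector input reuse the same path.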

paddle/gserver/layers/FactorizationMachineLayer.h

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ class FactorizationMachineLayer : public Layer {
   // Store temporary calculation result
   MatrixPtr tmpOut_;
   MatrixPtr tmpSum_;
+  MatrixPtr tmpInput_;
   // Negative identity matrix
   MatrixPtr negOnes_;
 

paddle/math/CpuSparseMatrix.cpp

Lines changed: 19 additions & 7 deletions
@@ -266,13 +266,25 @@ void CpuSparseMatrix::rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c) {
   CHECK_EQ(width_, b.getWidth());
   real* A = getValue();
   real* B = b.getValue();
-  for (size_t i = 0; i < height_; i++) {
-    size_t start = getRowStartIdx(i);
-    size_t end = getRowStartIdx(i + 1);
-    CHECK_EQ(start, b.getRowStartIdx(i));
-    CHECK_EQ(end, b.getRowStartIdx(i + 1));
-    for (size_t j = start; j < end; j++) {
-      A[j] = B[j] * c.getElement(i, cCol);
+  if (b.getValueType() == FLOAT_VALUE) {
+    for (size_t i = 0; i < height_; i++) {
+      size_t start = getRowStartIdx(i);
+      size_t end = getRowStartIdx(i + 1);
+      CHECK_EQ(start, b.getRowStartIdx(i));
+      CHECK_EQ(end, b.getRowStartIdx(i + 1));
+      for (size_t j = start; j < end; j++) {
+        A[j] = B[j] * c.getElement(i, cCol);
+      }
+    }
+  } else if (b.getValueType() == NO_VALUE) {
+    for (size_t i = 0; i < height_; i++) {
+      size_t start = getRowStartIdx(i);
+      size_t end = getRowStartIdx(i + 1);
+      CHECK_EQ(start, b.getRowStartIdx(i));
+      CHECK_EQ(end, b.getRowStartIdx(i + 1));
+      for (size_t j = start; j < end; j++) {
+        A[j] = c.getElement(i, cCol);
+      }
     }
   }
 }
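
Not part of the commit, but a minimal, self-contained sketch of the rowScale semantics added above, using a hand-rolled CSR container rather than Paddle's CpuSparseMatrix (the struct and function names below are illustrative, not the library's API). When b carries FLOAT_VALUE data, each stored value is multiplied by the per-row scale c(i, cCol); when b is NO_VALUE (the sparse_binary_vector case, where every non-zero is implicitly 1), the scale itself is written.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative CSR container; not Paddle's CpuSparseMatrix.
struct CsrMatrix {
  size_t height;
  std::vector<size_t> rowStart;  // size height + 1
  std::vector<size_t> cols;      // column index of each non-zero
  std::vector<float> vals;       // empty for a "binary" (NO_VALUE) matrix
  bool hasValues() const { return !vals.empty(); }
};

// Mirrors the two branches of CpuSparseMatrix::rowScale: out shares b's
// sparsity pattern, and scale[i] plays the role of c.getElement(i, cCol).
void rowScaleSketch(CsrMatrix& out, const CsrMatrix& b,
                    const std::vector<float>& scale) {
  assert(out.rowStart == b.rowStart);
  for (size_t i = 0; i < b.height; ++i) {
    for (size_t j = b.rowStart[i]; j < b.rowStart[i + 1]; ++j) {
      // FLOAT_VALUE: scale the stored value; NO_VALUE: the value is implicitly 1.
      out.vals[j] = b.hasValues() ? b.vals[j] * scale[i] : scale[i];
    }
  }
}

int main() {
  // A 2x4 binary sparse input (sparse_binary_vector style):
  // row 0 has non-zeros in columns {0, 2}, row 1 in columns {1, 3}.
  CsrMatrix b{2, {0, 2, 4}, {0, 2, 1, 3}, {}};
  CsrMatrix out{2, b.rowStart, b.cols, std::vector<float>(4, 0.f)};
  rowScaleSketch(out, b, {0.5f, 2.0f});
  for (float v : out.vals) std::printf("%g ", v);  // prints: 0.5 0.5 2 2
  std::printf("\n");
  return 0;
}

In the layer's backward pass, this is what lets the sparse branch work when inputV comes from a sparse_binary_vector slot, whose value array is absent.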

0 commit comments
