Skip to content

Commit 74a699a

Browse files
committed
change clone to resizeOrCreate in fm layer
1 parent 13ec6f9 commit 74a699a

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

paddle/gserver/layers/FactorizationMachineLayer.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,30 @@ void FactorizationMachineLayer::forward(PassType passType) {
5858
inputMulFactor_, batchSize, factorSize_, false, useGpu_);
5959
Matrix::resizeOrCreate(tmpOut_, batchSize, factorSize_, false, useGpu_);
6060

61-
REGISTER_TIMER_INFO("InputMulFactorTimer", getName().c_str());
61+
REGISTER_TIMER_INFO("FmInputMulFactorTimer", getName().c_str());
6262
inputMulFactor_->mul(*inputV, *latentVectors_->getW());
6363
inputMulFactor_->square2(*tmpOut_);
6464
outV->sumRows(*tmpOut_, 0.5, 0);
6565

66-
inputSquare_ = inputV->clone(0, 0, useGpu_);
67-
if (dynamic_cast<CpuSparseMatrix*>(inputSquare_.get())) {
66+
if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
67+
Matrix::resizeOrCreateSparseMatrix(inputSquare_,
68+
inputV->getHeight(),
69+
inputV->getWidth(),
70+
inputV->getElementCnt(),
71+
inputV->getValueType());
6872
inputSquare_->copyFrom(*inputV);
6973
(dynamic_cast<CpuSparseMatrix*>(inputSquare_.get()))->square2();
7074
} else {
75+
Matrix::resizeOrCreate(
76+
inputSquare_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
7177
inputV->square2(*inputSquare_);
7278
}
7379
latentVectors_->getW()->square2(*latentVectorsSquare_);
7480
tmpOut_->mul(*inputSquare_, *latentVectorsSquare_);
7581
outV->sumRows(*tmpOut_, -0.5, 1.0);
7682

7783
/* activation */ {
78-
REGISTER_TIMER_INFO("FmAtvTimer", getName().c_str());
84+
REGISTER_TIMER_INFO("FmFwAtvTimer", getName().c_str());
7985
forwardActivation();
8086
}
8187
}

0 commit comments

Comments
 (0)