@@ -58,24 +58,30 @@ void FactorizationMachineLayer::forward(PassType passType) {
58
58
inputMulFactor_, batchSize, factorSize_, false , useGpu_);
59
59
Matrix::resizeOrCreate (tmpOut_, batchSize, factorSize_, false , useGpu_);
60
60
61
- REGISTER_TIMER_INFO (" InputMulFactorTimer " , getName ().c_str ());
61
+ REGISTER_TIMER_INFO (" FmInputMulFactorTimer " , getName ().c_str ());
62
62
inputMulFactor_->mul (*inputV, *latentVectors_->getW ());
63
63
inputMulFactor_->square2 (*tmpOut_);
64
64
outV->sumRows (*tmpOut_, 0.5 , 0 );
65
65
66
- inputSquare_ = inputV->clone (0 , 0 , useGpu_);
67
- if (dynamic_cast <CpuSparseMatrix*>(inputSquare_.get ())) {
66
+ if (dynamic_cast <CpuSparseMatrix*>(inputV.get ())) {
67
+ Matrix::resizeOrCreateSparseMatrix (inputSquare_,
68
+ inputV->getHeight (),
69
+ inputV->getWidth (),
70
+ inputV->getElementCnt (),
71
+ inputV->getValueType ());
68
72
inputSquare_->copyFrom (*inputV);
69
73
(dynamic_cast <CpuSparseMatrix*>(inputSquare_.get ()))->square2 ();
70
74
} else {
75
+ Matrix::resizeOrCreate (
76
+ inputSquare_, inputV->getHeight (), inputV->getWidth (), false , useGpu_);
71
77
inputV->square2 (*inputSquare_);
72
78
}
73
79
latentVectors_->getW ()->square2 (*latentVectorsSquare_);
74
80
tmpOut_->mul (*inputSquare_, *latentVectorsSquare_);
75
81
outV->sumRows (*tmpOut_, -0.5 , 1.0 );
76
82
77
83
/* activation */ {
78
- REGISTER_TIMER_INFO (" FmAtvTimer " , getName ().c_str ());
84
+ REGISTER_TIMER_INFO (" FmFwAtvTimer " , getName ().c_str ());
79
85
forwardActivation ();
80
86
}
81
87
}
0 commit comments