Skip to content

Commit 5ccf84a

Browse files
authored
Merge pull request #383 from lzhao4ever/fix_matrix_inverse
Fix matrix inverse unittest to be more robust
2 parents 05204af + 992ac8f commit 5ccf84a

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

paddle/math/Matrix.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol,
187187
trans_, useGpu_);
188188
}
189189

190+
void Matrix::setDiag(real value) {
191+
CHECK(data_ != NULL);
192+
CHECK_EQ(height_, width_);
193+
194+
zeroMem();
195+
BaseMatrix diag(height_, 1, stride_ + 1, data_, false, useGpu_);
196+
diag.assign(value);
197+
}
198+
190199
GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans)
191200
: Matrix(std::make_shared<GpuMemoryHandle>(height * width * sizeof(real)),
192201
height, width, trans, true) {}
@@ -202,6 +211,7 @@ void GpuMatrix::resetOne() {
202211
CHECK(data_ != NULL);
203212
one();
204213
}
214+
205215
void GpuMatrix::resize(size_t newHeight, size_t newWidth) {
206216
size_t newSize = newHeight * newWidth;
207217
if (NULL == memoryHandle_.get() ||

paddle/math/Matrix.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ class Matrix : public BaseMatrix {
195195

196196
virtual void resetOne() { LOG(FATAL) << "Not implemented"; }
197197

198+
void setDiag(real value);
199+
198200
virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; }
199201

200202
virtual void trimFrom(const CpuSparseMatrix& src) {
@@ -330,6 +332,7 @@ class Matrix : public BaseMatrix {
330332

331333
virtual MatrixPtr getInverse() {
332334
LOG(FATAL) << "Not implemented";
335+
return nullptr;
333336
}
334337

335338
/**
@@ -1016,6 +1019,7 @@ class GpuMatrix : public Matrix {
10161019

10171020
void zeroMem();
10181021
void resetOne();
1022+
void setDiag(real value);
10191023

10201024
void resize(size_t newHeight, size_t newWidth);
10211025
void resize(size_t newHeight, size_t newWidth,
@@ -1280,6 +1284,8 @@ class CpuMatrix : public Matrix {
12801284

12811285
void zeroMem();
12821286
void resetOne();
1287+
void setDiag(real value);
1288+
12831289
void resize(size_t newHeight, size_t newWidth);
12841290
void resize(size_t newHeight, size_t newWidth,
12851291
size_t newNnz, /* used to allocate space */

paddle/math/tests/test_matrixCompare.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -647,20 +647,23 @@ void testMatrixInverse(int height) {
647647
MatrixPtr cpuI = std::make_shared<CpuMatrix>(height, height);
648648
MatrixPtr gpuI = std::make_shared<GpuMatrix>(height, height);
649649

650+
/* Make matrix well conditioned: cpu * cpuT + Identity */
650651
cpu->randomizeUniform();
652+
MatrixPtr cpuT = cpu->getTranspose();
653+
MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, height);
654+
outputCheck->mul(cpu, cpuT);
655+
cpu->setDiag(1.0);
656+
cpu->add(*outputCheck);
657+
651658
gpu->copyFrom(*cpu);
652659
cpu->inverse(cpuI, false);
653660
gpu->inverse(gpuI, false);
654661

655-
MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, height);
656662
outputCheck->copyFrom(*gpuI);
657663
MatrixCheckErr(*cpuI, *outputCheck);
658664

659665
outputCheck->mul(cpu, cpuI);
660-
cpu->zeroMem();
661-
for (int i = 0; i < height; i++) {
662-
cpu->getRowBuf(i)[i] = 1.0;
663-
}
666+
cpu->setDiag(1.0);
664667
MatrixCheckErr(*cpu, *outputCheck);
665668
}
666669

0 commit comments

Comments
 (0)