Skip to content

Commit 992ac8f

Browse files
author
Liang Zhao
committed
Implement setDiag() with BaseMatrix::assign()
1 parent 8c40bfd commit 992ac8f

File tree

2 files changed

+10
-21
lines changed

2 files changed

+10
-21
lines changed

paddle/math/Matrix.cpp

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol,
187187
trans_, useGpu_);
188188
}
189189

190+
void Matrix::setDiag(real value) {
191+
CHECK(data_ != NULL);
192+
CHECK_EQ(height_, width_);
193+
194+
zeroMem();
195+
BaseMatrix diag(height_, 1, stride_ + 1, data_, false, useGpu_);
196+
diag.assign(value);
197+
}
198+
190199
GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans)
191200
: Matrix(std::make_shared<GpuMemoryHandle>(height * width * sizeof(real)),
192201
height, width, trans, true) {}
@@ -203,16 +212,6 @@ void GpuMatrix::resetOne() {
203212
one();
204213
}
205214

206-
void GpuMatrix::setDiag(real value) {
207-
CHECK(data_ != NULL);
208-
CHECK_EQ(height_, width_);
209-
210-
zeroMem();
211-
for (size_t i = 0; i < height_; i++) {
212-
hl_memcpy_host2device(&data_[i * stride_ + i], &value, sizeof(real));
213-
}
214-
}
215-
216215
void GpuMatrix::resize(size_t newHeight, size_t newWidth) {
217216
size_t newSize = newHeight * newWidth;
218217
if (NULL == memoryHandle_.get() ||
@@ -1255,16 +1254,6 @@ void CpuMatrix::resetOne() {
12551254
BaseMatrix::one();
12561255
}
12571256

1258-
void CpuMatrix::setDiag(real value) {
1259-
CHECK(data_ != NULL);
1260-
CHECK_EQ(height_, width_);
1261-
1262-
zeroMem();
1263-
for (size_t i = 0; i < height_; i++) {
1264-
data_[i * stride_ + i] = value;
1265-
}
1266-
}
1267-
12681257
void CpuMatrix::copyFrom(const Matrix& src) {
12691258
CHECK(isContiguous());
12701259
if (typeid(src) == typeid(GpuMatrix)) {

paddle/math/Matrix.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class Matrix : public BaseMatrix {
195195

196196
virtual void resetOne() { LOG(FATAL) << "Not implemented"; }
197197

198-
virtual void setDiag(real value) { LOG(FATAL) << "Not implemented"; }
198+
void setDiag(real value);
199199

200200
virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; }
201201

0 commit comments

Comments
 (0)