Skip to content

Commit 1c31bb9

Browse files
authored
Merge pull request #5543 from tensor-tang/ds2
add resize of MKLDNNMatrix
2 parents 23b9bc0 + e1b8f5f commit 1c31bb9

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

paddle/math/MKLDNNMatrix.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,7 @@ void MKLDNNMatrix::downSpatial() {
152152
}
153153
memory::desc md = memory::desc(dstDims, getDtype(), dstFmt);
154154
memory::primitive_desc pd = memory::primitive_desc(md, getEngine());
155-
mkldnn_primitive_t result;
156-
mkldnn::error::wrap_c_api(
157-
mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
158-
"could not create a memory primitive");
159-
reset(result);
160-
set_data_handle(data_);
155+
resetMKLDNNMemory(pd, data_);
161156
}
162157

163158
} // namespace paddle

paddle/math/MKLDNNMatrix.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,27 @@ class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
145145
m_.reset();
146146
}
147147

148+
/**
149+
* override the CpuMatrix::resize
150+
*/
151+
void resize(size_t newHeight, size_t newWidth) override {
152+
m_->resize(newHeight, newWidth);
153+
if (data_ == m_->getData() && elementCnt_ == newHeight * newWidth) {
154+
return;
155+
}
156+
CpuMatrix::setData(data_);
157+
height_ = newHeight;
158+
width_ = newWidth;
159+
elementCnt_ = newHeight * newWidth;
160+
stride_ = width_;
161+
auto pd = mkldnn::memory::primitive_desc(
162+
mkldnn::memory::desc({(int)newHeight, (int)newWidth},
163+
getDtype(),
164+
mkldnn::memory::format::nc),
165+
getEngine());
166+
resetMKLDNNMemory(pd, data_);
167+
}
168+
148169
/**
149170
* override Matrix::getData
150171
* check data before return
@@ -215,6 +236,17 @@ class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
215236
memory::format srcFmt,
216237
memory::format dstFmt,
217238
memory::dims dm);
239+
/**
240+
* reset this MKLDNN Memory from primitve desc
241+
*/
242+
void resetMKLDNNMemory(memory::primitive_desc pd, real* data) {
243+
mkldnn_primitive_t result;
244+
mkldnn::error::wrap_c_api(
245+
mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
246+
"could not create a memory primitive");
247+
reset(result);
248+
set_data_handle(data);
249+
}
218250

219251
private:
220252
// save the CpuMatrixPtr in case the buffer released outside

0 commit comments

Comments
 (0)