@@ -145,6 +145,27 @@ class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
145
145
m_.reset ();
146
146
}
147
147
148
+ /* *
149
+ * override the CpuMatrix::resize
150
+ */
151
+ void resize (size_t newHeight, size_t newWidth) override {
152
+ m_->resize (newHeight, newWidth);
153
+ if (data_ == m_->getData () && elementCnt_ == newHeight * newWidth) {
154
+ return ;
155
+ }
156
+ CpuMatrix::setData (data_);
157
+ height_ = newHeight;
158
+ width_ = newWidth;
159
+ elementCnt_ = newHeight * newWidth;
160
+ stride_ = width_;
161
+ auto pd = mkldnn::memory::primitive_desc (
162
+ mkldnn::memory::desc ({(int )newHeight, (int )newWidth},
163
+ getDtype (),
164
+ mkldnn::memory::format::nc),
165
+ getEngine ());
166
+ resetMKLDNNMemory (pd, data_);
167
+ }
168
+
148
169
/* *
149
170
* override Matrix::getData
150
171
* check data before return
@@ -215,6 +236,17 @@ class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
215
236
memory::format srcFmt,
216
237
memory::format dstFmt,
217
238
memory::dims dm);
239
+ /* *
240
+ * reset this MKLDNN Memory from primitve desc
241
+ */
242
+ void resetMKLDNNMemory (memory::primitive_desc pd, real* data) {
243
+ mkldnn_primitive_t result;
244
+ mkldnn::error::wrap_c_api (
245
+ mkldnn_primitive_create (&result, pd.get (), nullptr , nullptr ),
246
+ " could not create a memory primitive" );
247
+ reset (result);
248
+ set_data_handle (data);
249
+ }
218
250
219
251
private:
220
252
// save the CpuMatrixPtr in case the buffer released outside
0 commit comments