Skip to content

Commit 2ed7982

Browse files
authored
Merge pull request #13327 from kbinias/kbinias/conv-weights-converted-once
[MKLDNN] Reusing once reordered convolution weights in test mode
2 parents b681537 + accdecc commit 2ed7982

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
130130

131131
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
132132
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
133-
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
133+
std::vector<mkldnn::primitive>& pipeline, // NOLINT
134+
bool is_persistent = false) {
134135
auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
135136
auto weights_pd = conv_pd_->weights_primitive_desc();
136137
return this->AcquireMemory(weights_pd, user_weights_pd,
137138
user_weights_memory_p, "@weights_mem_p",
138-
pipeline);
139+
pipeline, is_persistent);
139140
}
140141

141142
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
266267
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
267268
"It must use CPUPlace.");
268269

270+
const bool is_test = ctx.Attr<bool>("is_test");
271+
269272
auto& dev_ctx =
270273
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
271274
const auto& mkldnn_engine = dev_ctx.GetEngine();
@@ -371,7 +374,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
371374
auto src_memory_p =
372375
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
373376
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
374-
user_weights_memory_p, pipeline);
377+
user_weights_memory_p, pipeline, is_test);
375378
auto dst_memory_p =
376379
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
377380

paddle/fluid/operators/conv_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
109109
}
110110

111111
void Conv2DOpMaker::Make() {
112+
AddAttr<bool>("is_test", "").SetDefault(false);
112113
AddInput(
113114
"Input",
114115
"(Tensor) The input tensor of convolution operator. "

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ class MKLDNNHandler {
192192
mkldnn::memory::primitive_desc& user_mpd, // NOLINT
193193
const std::shared_ptr<mkldnn::memory> user_memory_p,
194194
const std::string& suffix,
195-
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
195+
std::vector<mkldnn::primitive>& pipeline, // NOLINT
196+
bool is_persistent = false) {
196197
// create reorder primitive if the input format is not the preferred one
197198
auto local_key = key_ + suffix;
198199
auto key_reorder_p = key_ + suffix + "reorder_p";
@@ -213,7 +214,7 @@ class MKLDNNHandler {
213214
pipeline.push_back(*reorder_p);
214215
}
215216
dev_ctx_.SetBlob(local_key, target_memory_p);
216-
} else {
217+
} else if (!is_persistent) {
217218
// Make reorder if needed
218219
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
219220
dev_ctx_.GetBlob(key_reorder_p));

0 commit comments

Comments
 (0)