Skip to content

Commit 1658958

Browse files
committed
Reusing converted weights
1 parent a557608 commit 1658958

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
130130

131131
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
132132
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
133-
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
133+
const std::vector<mkldnn::primitive>& pipeline,
134+
bool is_test = false) { // NOLINT
134135
auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
135136
auto weights_pd = conv_pd_->weights_primitive_desc();
136137
return this->AcquireMemory(weights_pd, user_weights_pd,
137138
user_weights_memory_p, "@weights_mem_p",
138-
pipeline);
139+
pipeline, is_test);
139140
}
140141

141142
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
266267
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
267268
"It must use CPUPlace.");
268269

270+
const bool is_test = ctx.Attr<bool>("is_test");
271+
269272
auto& dev_ctx =
270273
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
271274
const auto& mkldnn_engine = dev_ctx.GetEngine();
@@ -371,7 +374,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
371374
auto src_memory_p =
372375
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
373376
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
374-
user_weights_memory_p, pipeline);
377+
user_weights_memory_p, pipeline, is_test);
375378
auto dst_memory_p =
376379
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
377380

paddle/fluid/operators/conv_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
109109
}
110110

111111
void Conv2DOpMaker::Make() {
112+
AddAttr<bool>("is_test", "").SetDefault(false);
112113
AddInput(
113114
"Input",
114115
"(Tensor) The input tensor of convolution operator. "

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ class MKLDNNHandler {
191191
mkldnn::memory::primitive_desc& mpd, // NOLINT
192192
mkldnn::memory::primitive_desc& user_mpd, // NOLINT
193193
const std::shared_ptr<mkldnn::memory> user_memory_p,
194-
const std::string& suffix,
195-
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
194+
const std::string& suffix, const std::vector<mkldnn::primitive>& pipeline,
195+
bool is_test = false) { // NOLINT
196196
// create reorder primitive if the input format is not the preferred one
197197
auto local_key = key_ + suffix;
198198
auto key_reorder_p = key_ + suffix + "reorder_p";
@@ -213,7 +213,7 @@ class MKLDNNHandler {
213213
pipeline.push_back(*reorder_p);
214214
}
215215
dev_ctx_.SetBlob(local_key, target_memory_p);
216-
} else {
216+
} else if (!is_test) {
217217
// Make reorder if needed
218218
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
219219
dev_ctx_.GetBlob(key_reorder_p));

0 commit comments

Comments
 (0)