Skip to content

Commit edd1126

Browse files
authored
build_strategy.cc modify mkldnn_enabled_op_types [fluid_ops] (#74417)
* Fix
* Fix
1 parent 2cc3e04 commit edd1126

File tree

12 files changed

+47
-36
lines changed

12 files changed

+47
-36
lines changed

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
321321
continue;
322322
}
323323
} else if (pass->Type() == "onednn_placement_pass") {
324-
pass->Set("mkldnn_enabled_op_types",
324+
pass->Set("onednn_enabled_op_types",
325325
new std::unordered_set<std::string>(onednn_enabled_op_types_));
326326
}
327327
VLOG(1) << "Start Apply Pass " << pass->Type();

paddle/fluid/framework/ir/onednn/onednn_placement_pass.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ inline bool FoundPhiOneDNNKernelWithCorrectDataType(
6464
return false;
6565
}
6666

67-
bool MKLDNNPlacementPass::IsSupport(const Node* op) const {
67+
bool ONEDNNPlacementPass::IsSupport(const Node* op) const {
6868
if (FoundOneDNNKernelWithCorrectDataType(op) ||
6969
FoundPhiOneDNNKernelWithCorrectDataType(op)) {
7070
// For interpolate ops, there's a little difference between Paddle and
@@ -89,8 +89,8 @@ bool MKLDNNPlacementPass::IsSupport(const Node* op) const {
8989

9090
} // namespace paddle::framework::ir
9191

92-
REGISTER_PASS(onednn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass)
93-
.RequirePassAttr("mkldnn_enabled_op_types");
92+
REGISTER_PASS(onednn_placement_pass, paddle::framework::ir::ONEDNNPlacementPass)
93+
.RequirePassAttr("onednn_enabled_op_types");
9494

9595
REGISTER_PASS_CAPABILITY(onednn_placement_pass)
9696
.AddCombination(

paddle/fluid/framework/ir/onednn/onednn_placement_pass.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,17 @@ namespace ir {
2626
/*
2727
* Specifies which operators should use MKLDNN.
2828
*/
29-
class MKLDNNPlacementPass : public PlacementPassBase {
29+
class ONEDNNPlacementPass : public PlacementPassBase {
3030
protected:
3131
bool IsSupport(const Node* op) const override;
3232

3333
private:
34-
const std::string GetPlacementName() const override { return "MKLDNN"; }
34+
const std::string GetPlacementName() const override { return "ONEDNN"; }
3535

3636
const std::string GetAttrName() const override { return "use_mkldnn"; }
3737

3838
const std::unordered_set<std::string> GetOpTypesList() const override {
39-
return Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
39+
return Get<std::unordered_set<std::string>>("onednn_enabled_op_types");
4040
}
4141
};
4242

paddle/fluid/framework/ir/onednn/onednn_placement_pass_tester.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class PlacementPassTest {
133133

134134
auto pass = PassRegistry::Instance().Get("onednn_placement_pass");
135135

136-
pass->Set("mkldnn_enabled_op_types",
136+
pass->Set("onednn_enabled_op_types",
137137
new std::unordered_set<std::string>(onednn_enabled_op_types));
138138

139139
graph.reset(pass->Apply(graph.release()));
@@ -143,8 +143,10 @@ class PlacementPassTest {
143143
for (auto* node : graph->Nodes()) {
144144
if (node->IsOp()) {
145145
auto* op = node->Op();
146-
if (op->HasAttr("use_mkldnn") &&
147-
PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))) {
146+
if ((op->HasAttr("use_mkldnn") &&
147+
PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))) ||
148+
(op->HasAttr("use_onednn") &&
149+
PADDLE_GET_CONST(bool, op->GetAttr("use_onednn")))) {
148150
++use_onednn_true_count;
149151
}
150152
}
@@ -156,27 +158,27 @@ class PlacementPassTest {
156158
void PlacementNameTest() {
157159
auto pass = PassRegistry::Instance().Get("onednn_placement_pass");
158160
EXPECT_EQ(static_cast<PlacementPassBase*>(pass.get())->GetPlacementName(),
159-
"MKLDNN");
161+
"ONEDNN");
160162
}
161163
};
162164

163-
TEST(MKLDNNPlacementPass, enable_conv_relu) {
165+
TEST(ONEDNNPlacementPass, enable_conv_relu) {
164166
// 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 0 pool
165167
PlacementPassTest().MainTest({"conv2d", "relu"}, 4);
166168
}
167169

168-
TEST(MKLDNNPlacementPass, enable_relu_pool) {
170+
TEST(ONEDNNPlacementPass, enable_relu_pool) {
169171
// 1 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool
170172
PlacementPassTest().MainTest({"relu", "pool2d"}, 4);
171173
}
172174

173-
TEST(MKLDNNPlacementPass, enable_all) {
175+
TEST(ONEDNNPlacementPass, enable_all) {
174176
// 2 conv (1 conv is always true) + 2 relu (1 relu is always true) + 1 pool +
175177
// 1 concat
176178
PlacementPassTest().MainTest({}, 6);
177179
}
178180

179-
TEST(MKLDNNPlacementPass, placement_name) {
181+
TEST(ONEDNNPlacementPass, placement_name) {
180182
PlacementPassTest().PlacementNameTest();
181183
}
182184

paddle/fluid/inference/analysis/argument.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,12 @@ struct Argument {
193193
// whether to mute all logs in inference.
194194
DECL_ARGUMENT_FIELD(disable_logs, DisableLogs, bool);
195195

196-
// Pass a set of op types to enable its mkldnn kernel
197-
DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types,
198-
MKLDNNEnabledOpTypes,
196+
// Pass a set of op types to enable its onednn kernel
197+
DECL_ARGUMENT_FIELD(onednn_enabled_op_types,
198+
ONEDNNEnabledOpTypes,
199199
std::unordered_set<std::string>);
200-
// The cache capacity of different input shapes for mkldnn.
201-
DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, MkldnnCacheCapacity, int);
200+
// The cache capacity of different input shapes for onednn.
201+
DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, OnednnCacheCapacity, int);
202202

203203
#ifdef PADDLE_WITH_DNNL
204204
// A set of op types to enable their quantized kernels
@@ -219,7 +219,7 @@ struct Argument {
219219
Bfloat16EnabledOpTypes,
220220
std::unordered_set<std::string>);
221221

222-
DECL_ARGUMENT_FIELD(use_onednn_int8, UseMkldnnInt8, bool);
222+
DECL_ARGUMENT_FIELD(use_onednn_int8, UseOnednnInt8, bool);
223223
#endif
224224

225225
// Passed from config.

paddle/fluid/inference/analysis/ir_pass_manager.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,9 @@ void IRPassManager::CreatePasses(Argument *argument,
131131
pass->Set("optim_cache_dir", new std::string(std::move(optim_cache_dir)));
132132
pass_num++;
133133
} else if (pass_name == "onednn_placement_pass") {
134-
pass->Set("mkldnn_enabled_op_types",
134+
pass->Set("onednn_enabled_op_types",
135135
new std::unordered_set<std::string>(
136-
argument->mkldnn_enabled_op_types()));
136+
argument->onednn_enabled_op_types()));
137137
} else if (pass_name == "cudnn_placement_pass") {
138138
pass->Set("cudnn_enabled_op_types",
139139
new std::unordered_set<std::string>());

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,8 +1031,8 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
10311031
}
10321032
#endif
10331033
#ifdef PADDLE_WITH_DNNL
1034-
} else if (config_.mkldnn_enabled()) {
1035-
// mkldnn
1034+
} else if (config_.onednn_enabled()) {
1035+
// onednn
10361036
pir::IrContext *ctx = pir::IrContext::Instance();
10371037
ctx->GetOrRegisterDialect<paddle::dialect::OneDNNOperatorDialect>();
10381038
if (!config_.custom_pass_only_) {
@@ -2100,9 +2100,9 @@ void AnalysisPredictor::PrepareArgument() {
21002100
argument_->SetIpuCustomPatterns(config_.ipu_custom_patterns_);
21012101
#endif
21022102

2103-
if (config_.mkldnn_enabled() && !config_.use_gpu()) {
2104-
LOG(INFO) << "MKLDNN is enabled";
2105-
argument_->SetMKLDNNEnabledOpTypes(config_.onednn_enabled_op_types_);
2103+
if (config_.onednn_enabled() && !config_.use_gpu()) {
2104+
LOG(INFO) << "ONEDNN is enabled";
2105+
argument_->SetONEDNNEnabledOpTypes(config_.onednn_enabled_op_types_);
21062106
}
21072107

21082108
if (config_.cinn_enabled()) {
@@ -2115,7 +2115,7 @@ void AnalysisPredictor::PrepareArgument() {
21152115
argument_->SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_);
21162116
}
21172117

2118-
if (config_.mkldnn_int8_enabled()) {
2118+
if (config_.onednn_int8_enabled()) {
21192119
LOG(INFO) << "Int8 is enabled";
21202120
argument_->SetQuantizeEnabledOpTypes(config_.quantize_enabled_op_types_);
21212121
argument_->SetQuantizeExcludedOpIds(config_.quantize_excluded_op_ids_);
@@ -2296,7 +2296,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
22962296
#if defined(_WIN32)
22972297
argument_->PartiallyRelease();
22982298
#else
2299-
if (config_.mkldnn_enabled() ||
2299+
if (config_.onednn_enabled() ||
23002300
(config_.tensorrt_engine_enabled() &&
23012301
config_.tensorrt_precision_mode_ ==
23022302
AnalysisConfig::Precision::kInt8)) { // NOLINT

paddle/fluid/inference/capi/pd_config.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ bool PD_OnednnEnabled(const PD_AnalysisConfig* config) {
311311
config,
312312
common::errors::InvalidArgument(
313313
"The pointer of analysis configuration shouldn't be nullptr"));
314-
return config->config.mkldnn_enabled();
314+
return config->config.onednn_enabled();
315315
}
316316

317317
void PD_SetCpuMathLibraryNumThreads(PD_AnalysisConfig* config,

paddle/fluid/jit/engine/interpreter_engine.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void InterpreterEngine::CreateInterpreterCore() {
5353
#ifdef PADDLE_WITH_DNNL
5454
auto onednn_pass =
5555
framework::ir::PassRegistry::Instance().Get("onednn_placement_pass");
56-
onednn_pass->Set("mkldnn_enabled_op_types",
56+
onednn_pass->Set("onednn_enabled_op_types",
5757
new std::unordered_set<std::string>({}));
5858
onednn_pass->Apply(&graph);
5959
#endif

paddle/fluid/operators/generator/get_expected_kernel_func.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ static bool ReduceOpHasOptimizedOneDNNKernel(
6161
}
6262

6363
// only poolop
64-
bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
64+
bool CanONEDNNSupportPool(const framework::ExecutionContext& ctx) {
6565
if (ctx.Attr<bool>("adaptive") == false) return true;
6666
// oneDNN is supporting only unchangeable in size pool window
6767
auto src_tz = common::vectorize(ctx.Input<phi::DenseTensor>("X")->dims());
@@ -181,7 +181,7 @@ phi::KernelKey GetPoolExpectedKernelType(
181181
auto data_type = op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "X");
182182

183183
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_DNNL
184-
op_ptr->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
184+
op_ptr->SetDnnFallback(!CanONEDNNSupportPool(ctx));
185185
// NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_DNNL
186186

187187
return phi::KernelKey(data_type, ctx.GetPlace());
@@ -194,7 +194,7 @@ phi::KernelKey GetPoolDoubleGradExpectedKernelType(
194194
op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "grad_x@GRAD");
195195

196196
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_DNNL
197-
op_ptr->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
197+
op_ptr->SetDnnFallback(!CanONEDNNSupportPool(ctx));
198198
// NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_DNNL
199199

200200
return phi::KernelKey(data_type, ctx.GetPlace());

0 commit comments

Comments (0)