Skip to content

Commit 90e4213

Browse files
authored
test directory modify FLAGS_use_mkldnn [fluid_ops] (#74423)
1 parent de04d9e commit 90e4213

12 files changed

+46
-46
lines changed

test/deprecated/legacy_test/test_eager_deletion_delete_vars_deprecated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import numpy as np
1818

19-
os.environ['FLAGS_use_mkldnn'] = '0'
19+
os.environ['FLAGS_use_onednn'] = '0'
2020
os.environ['CPU_NUM'] = '4'
2121

2222
import unittest

test/deprecated/quantization/CMakeLists.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function(_inference_analysis_python_api_int8_test target model_dir data_path
1111
SRCS ${filename}
1212
ENVS
1313
CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
14-
FLAGS_use_mkldnn=${use_mkldnn}
14+
FLAGS_use_onednn=${use_mkldnn}
1515
ARGS
1616
--infer_model
1717
${model_dir}/model
@@ -67,7 +67,7 @@ function(inference_quant_int8_image_classification_test target quant_model_dir
6767
ENVS
6868
FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
6969
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
70-
FLAGS_use_mkldnn=true
70+
FLAGS_use_onednn=true
7171
ARGS
7272
--quant_model
7373
${quant_model_dir}
@@ -91,7 +91,7 @@ function(inference_quant2_int8_image_classification_test target quant_model_dir
9191
ENVS
9292
FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
9393
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
94-
FLAGS_use_mkldnn=true
94+
FLAGS_use_onednn=true
9595
ARGS
9696
--quant_model
9797
${quant_model_dir}
@@ -123,7 +123,7 @@ function(
123123
ENVS
124124
FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
125125
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
126-
FLAGS_use_mkldnn=true
126+
FLAGS_use_onednn=true
127127
ARGS
128128
--quant_model
129129
${quant_model_dir}

test/dygraph_to_static/test_build_strategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ def test_resnet(self):
8080

8181
@test_default_mode_only
8282
def test_in_static_mode_mkldnn(self):
83-
paddle.set_flags({'FLAGS_use_mkldnn': True})
83+
paddle.set_flags({'FLAGS_use_onednn': True})
8484
try:
8585
if paddle.base.core.is_compiled_with_mkldnn():
8686
self.resnet_helper.train(True, self.build_strategy)
8787
finally:
88-
paddle.set_flags({'FLAGS_use_mkldnn': False})
88+
paddle.set_flags({'FLAGS_use_onednn': False})
8989

9090

9191
class TestError(Dy2StTestBase):

test/dygraph_to_static/test_mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,11 @@ def test_mnist_to_static(self):
176176
@test_default_mode_only
177177
def test_mnist_declarative_cpu_vs_mkldnn(self):
178178
dygraph_loss_cpu = self.train_dygraph()
179-
paddle.set_flags({'FLAGS_use_mkldnn': True})
179+
paddle.set_flags({'FLAGS_use_onednn': True})
180180
try:
181181
dygraph_loss_mkldnn = self.train_dygraph()
182182
finally:
183-
paddle.set_flags({'FLAGS_use_mkldnn': False})
183+
paddle.set_flags({'FLAGS_use_onednn': False})
184184
np.testing.assert_allclose(
185185
dygraph_loss_cpu,
186186
dygraph_loss_mkldnn,

test/dygraph_to_static/test_resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,12 +481,12 @@ def test_resnet_composite(self):
481481

482482
@test_default_mode_only
483483
def test_in_static_mode_mkldnn(self):
484-
paddle.set_flags({'FLAGS_use_mkldnn': True})
484+
paddle.set_flags({'FLAGS_use_onednn': True})
485485
try:
486486
if paddle.base.core.is_compiled_with_mkldnn():
487487
self.train(to_static=True)
488488
finally:
489-
paddle.set_flags({'FLAGS_use_mkldnn': False})
489+
paddle.set_flags({'FLAGS_use_onednn': False})
490490

491491

492492
if __name__ == '__main__':

test/legacy_test/op_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2176,11 +2176,11 @@ def check_inplace_output_with_place(
21762176
else:
21772177
# TODO(zhiqiu): enhance inplace_grad test for ops (sum and activation) using mkldnn
21782178
# skip op that use_mkldnn currently
2179-
flags_use_mkldnn = base.core.globals()["FLAGS_use_mkldnn"]
2179+
flags_use_onednn = base.core.globals()["FLAGS_use_onednn"]
21802180
attrs_use_mkldnn = hasattr(self, 'attrs') and bool(
21812181
self.attrs.get('use_mkldnn', False)
21822182
)
2183-
if flags_use_mkldnn or attrs_use_mkldnn:
2183+
if flags_use_onednn or attrs_use_mkldnn:
21842184
warnings.warn(
21852185
"check inplace_grad for ops using mkldnn is not supported"
21862186
)
@@ -2216,7 +2216,7 @@ def check_output_with_place(
22162216
core._set_prim_all_enabled(False)
22172217
core.set_prim_eager_enabled(False)
22182218
if not self.is_onednn_op():
2219-
set_flags({"FLAGS_use_mkldnn": False})
2219+
set_flags({"FLAGS_use_onednn": False})
22202220

22212221
if hasattr(self, "use_custom_device") and self.use_custom_device:
22222222
check_dygraph = False
@@ -3285,7 +3285,7 @@ def check_grad_with_place(
32853285
check_dygraph = False
32863286

32873287
if not self.is_onednn_op():
3288-
set_flags({"FLAGS_use_mkldnn": False})
3288+
set_flags({"FLAGS_use_onednn": False})
32893289

32903290
core._set_prim_all_enabled(False)
32913291
core.set_prim_eager_enabled(False)

test/legacy_test/test_sgd_op_bf16.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ class TestSGDOpBF16API(unittest.TestCase):
239239
@classmethod
240240
def setUpClass(cls):
241241
np.random.seed(12345)
242-
base.set_flags({'FLAGS_use_mkldnn': True})
242+
base.set_flags({'FLAGS_use_onednn': True})
243243

244244
def setUp(self):
245245
self.sample_count = 20

test/mkldnn/check_flags_mkldnn_ops_on_off.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424

2525
def check():
2626
print(
27-
"check: _global_flags()['FLAGS_use_mkldnn']=",
28-
_global_flags()["FLAGS_use_mkldnn"],
27+
"check: _global_flags()['FLAGS_use_onednn']=",
28+
_global_flags()["FLAGS_use_onednn"],
2929
)
3030
print(
31-
"check: base.get_flags('FLAGS_use_mkldnn')=",
32-
base.get_flags(['FLAGS_use_mkldnn']),
31+
"check: base.get_flags('FLAGS_use_onednn')=",
32+
base.get_flags(['FLAGS_use_onednn']),
3333
)
3434
print("check: DNNL_VERBOSE=", os.environ['DNNL_VERBOSE'])
3535
print(

test/mkldnn/check_flags_use_mkldnn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424

2525
def check():
2626
print(
27-
"check: _global_flags()['FLAGS_use_mkldnn']=",
28-
_global_flags()["FLAGS_use_mkldnn"],
27+
"check: _global_flags()['FLAGS_use_onednn']=",
28+
_global_flags()["FLAGS_use_onednn"],
2929
)
3030
print(
31-
"check: base.get_flags('FLAGS_use_mkldnn')=",
32-
base.get_flags(['FLAGS_use_mkldnn']),
31+
"check: base.get_flags('FLAGS_use_onednn')=",
32+
base.get_flags(['FLAGS_use_onednn']),
3333
)
3434
print("check: DNNL_VERBOSE=", os.environ['DNNL_VERBOSE'])
3535
a_np = np.random.uniform(-2, 2, (10, 20, 30)).astype(np.float32)

test/mkldnn/test_flags_mkldnn_ops_on_off.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
import unittest
2020

2121

22-
class TestFlagsUseMkldnn(unittest.TestCase):
22+
class TestFlagsUseOnednn(unittest.TestCase):
2323
def setUp(self):
2424
self._python_interp = sys.executable
2525
self._python_interp += " check_flags_mkldnn_ops_on_off.py"
2626

2727
self.env = os.environ.copy()
2828
self.env["DNNL_VERBOSE"] = "1"
29-
self.env["FLAGS_use_mkldnn"] = "1"
29+
self.env["FLAGS_use_onednn"] = "1"
3030

3131
self.relu_regex = b"^onednn_verbose,exec,cpu,eltwise,.+alg:eltwise_relu alpha:0 beta:0,10x20x20"
3232
self.ew_add_regex = (
@@ -36,7 +36,7 @@ def setUp(self):
3636
b"^onednn_verbose,exec,cpu,matmul,.*10x20x30:10x30x20:10x20x20"
3737
)
3838

39-
def flags_use_mkl_dnn_common(self, e):
39+
def flags_use_onednn_common(self, e):
4040
cmd = self._python_interp
4141
env = dict(self.env, **e)
4242
proc = subprocess.Popen(
@@ -66,46 +66,46 @@ def not_found(self, regex, out, err):
6666
_not_found = not re.search(regex, out, re.MULTILINE)
6767
return self._print_when_false(_not_found, out, err)
6868

69-
def test_flags_use_mkl_dnn_on_empty_off_empty(self):
70-
out, err = self.flags_use_mkl_dnn_common({})
69+
def test_flags_use_onednn_on_empty_off_empty(self):
70+
out, err = self.flags_use_onednn_common({})
7171
assert self.found(self.relu_regex, out, err)
7272
assert self.found(self.ew_add_regex, out, err)
7373
assert self.found(self.matmul_regex, out, err)
7474

75-
def test_flags_use_mkl_dnn_on(self):
75+
def test_flags_use_onednn_on(self):
7676
env = {"FLAGS_tracer_onednn_ops_on": "relu"}
77-
out, err = self.flags_use_mkl_dnn_common(env)
77+
out, err = self.flags_use_onednn_common(env)
7878
assert self.found(self.relu_regex, out, err)
7979
assert self.not_found(self.ew_add_regex, out, err)
8080
assert self.not_found(self.matmul_regex, out, err)
8181

82-
def test_flags_use_mkl_dnn_on_multiple(self):
82+
def test_flags_use_onednn_on_multiple(self):
8383
env = {"FLAGS_tracer_onednn_ops_on": "relu,elementwise_add"}
84-
out, err = self.flags_use_mkl_dnn_common(env)
84+
out, err = self.flags_use_onednn_common(env)
8585
assert self.found(self.relu_regex, out, err)
8686
assert self.found(self.ew_add_regex, out, err)
8787
assert self.not_found(self.matmul_regex, out, err)
8888

89-
def test_flags_use_mkl_dnn_off(self):
89+
def test_flags_use_onednn_off(self):
9090
env = {"FLAGS_tracer_onednn_ops_off": "matmul_v2"}
91-
out, err = self.flags_use_mkl_dnn_common(env)
91+
out, err = self.flags_use_onednn_common(env)
9292
assert self.found(self.relu_regex, out, err)
9393
assert self.found(self.ew_add_regex, out, err)
9494
assert self.not_found(self.matmul_regex, out, err)
9595

96-
def test_flags_use_mkl_dnn_off_multiple(self):
96+
def test_flags_use_onednn_off_multiple(self):
9797
env = {"FLAGS_tracer_onednn_ops_off": "matmul_v2,relu"}
98-
out, err = self.flags_use_mkl_dnn_common(env)
98+
out, err = self.flags_use_onednn_common(env)
9999
assert self.not_found(self.relu_regex, out, err)
100100
assert self.found(self.ew_add_regex, out, err)
101101
assert self.not_found(self.matmul_regex, out, err)
102102

103-
def test_flags_use_mkl_dnn_on_off(self):
103+
def test_flags_use_onednn_on_off(self):
104104
env = {
105105
"FLAGS_tracer_onednn_ops_on": "elementwise_add",
106106
"FLAGS_tracer_onednn_ops_off": "matmul_v2",
107107
}
108-
out, err = self.flags_use_mkl_dnn_common(env)
108+
out, err = self.flags_use_onednn_common(env)
109109
assert self.not_found(self.relu_regex, out, err)
110110
assert self.found(self.ew_add_regex, out, err)
111111
assert self.not_found(self.matmul_regex, out, err)

0 commit comments

Comments
 (0)