Commit 1001dda

split get_places function, add is_custom_device in op_test (#74363)
* rebase latest
1 parent 51c12fb commit 1001dda

63 files changed: +324 −261 lines (only a subset of the changed files is shown below)
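
In short, the old get_places(string_format=...) helper in test/legacy_test/op_test.py is split into two helpers, and a new is_custom_device() check lets custom-device builds run the same tests: get_places() keeps returning Place objects (base.CPUPlace(), base.CUDAPlace(0), and now base.CustomPlace(dev_type, 0)), while the new get_devices() returns device strings for dygraph code that calls paddle.set_device(). A minimal usage sketch — the check_on_all_devices helper is hypothetical and assumes the script runs inside test/legacy_test so op_test is importable:

    import paddle
    from op_test import get_devices, get_places, is_custom_device

    def check_on_all_devices(static_check, dygraph_check):
        # OpTest-style static checks take Place objects.
        for place in get_places():
            static_check(place)
        # Dygraph/AMP tests switch the global device by string
        # ('cpu', 'gpu', or '<custom_type>:0').
        for device in get_devices():
            paddle.set_device(device)
            dygraph_check(device)

    if is_custom_device():
        # At least one custom-device plugin is compiled in.
        print(paddle.device.get_all_custom_device_type()[0])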

test/legacy_test/op_test.py

Lines changed: 47 additions & 19 deletions
@@ -389,29 +389,38 @@ def convert_uint16_to_float(in_list):
     return np.reshape(out, in_list.shape)
 
 
-def get_places(string_format=False):
+def get_places():
     places = []
-    if not string_format:
-        if (
-            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
-            in ['1', 'true', 'on']
-            or not core.is_compiled_with_cuda()
-        ):
-            places.append(base.CPUPlace())
-        if core.is_compiled_with_cuda():
-            places.append(base.CUDAPlace(0))
-    else:
-        if (
-            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
-            in ['1', 'true', 'on']
-            or not paddle.is_compiled_with_cuda()
-        ):
-            places.append('cpu')
-        if paddle.is_compiled_with_cuda():
-            places.append('gpu')
+    if (
+        os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
+        in ['1', 'true', 'on']
+        or not core.is_compiled_with_cuda()
+    ):
+        places.append(base.CPUPlace())
+    if core.is_compiled_with_cuda():
+        places.append(base.CUDAPlace(0))
+    if is_custom_device():
+        dev_type = paddle.device.get_all_custom_device_type()[0]
+        places.append(base.CustomPlace(dev_type, 0))
     return places
 
 
+def get_devices():
+    devices = []
+    if (
+        os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
+        in ['1', 'true', 'on']
+        or not paddle.is_compiled_with_cuda()
+    ):
+        devices.append('cpu')
+    if paddle.is_compiled_with_cuda():
+        devices.append('gpu')
+    if is_custom_device():
+        dev_type = paddle.device.get_all_custom_device_type()[0]
+        devices.append(f'{dev_type}:0')
+    return devices
+
+
 def get_device_place():
     if core.is_compiled_with_cuda():
         return base.CUDAPlace(0)
@@ -423,6 +432,15 @@ def get_device_place():
     return base.CPUPlace()
 
 
+def is_custom_device():
+    custom_dev_types = paddle.device.get_all_custom_device_type()
+    if custom_dev_types and paddle.device.is_compiled_with_custom_device(
+        custom_dev_types[0]
+    ):
+        return True
+    return False
+
+
 @contextmanager
 def auto_parallel_test_guard(test_info_path, generated_test_file_path):
     test_info_file, generated_test_file = None, None
@@ -2902,6 +2920,13 @@ def _get_places(self):
                     return [place]
                 else:
                     return []
+            elif is_custom_device():
+                dev_type = paddle.device.get_all_custom_device_type()[0]
+                place = core.CustomPlace(dev_type, 0)
+                if core.is_float16_supported(place):
+                    return [place]
+                else:
+                    return []
             else:
                 return []
         places = []
@@ -2931,6 +2956,9 @@ def _get_places(self):
             and not cpu_only
         ):
             places.append(core.CUDAPlace(0))
+        if is_custom_device():
+            dev_type = paddle.device.get_all_custom_device_type()[0]
+            places.append(core.CustomPlace(dev_type, 0))
         return places
 
     def check_output(
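
The test-file diffs shown below follow two mechanical patterns: dygraph/AMP tests replace get_places(string_format=True) with get_devices(), and CUDA-only skip guards are widened so a custom-device build no longer skips the test. A sketch of the widened guard, using a hypothetical TestExampleBF16 class purely for illustration:

    import unittest
    from op_test import OpTest, is_custom_device
    from paddle.base import core

    @unittest.skipIf(
        not (core.is_compiled_with_cuda() or is_custom_device())
        or core.is_compiled_with_rocm(),
        "core is not compiled with CUDA",
    )
    class TestExampleBF16(OpTest):  # hypothetical; real classes are e.g. TestSigmoidBF16
        pass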

test/legacy_test/test_activation_op.py

Lines changed: 13 additions & 7 deletions
@@ -23,6 +23,7 @@
     convert_float_to_uint16,
     get_device_place,
     get_places,
+    is_custom_device,
 )
 from scipy.special import erf, expit
 from utils import static_guard
@@ -497,7 +498,8 @@ def init_shape(self):
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
+    not (core.is_compiled_with_cuda() or is_custom_device())
+    or core.is_compiled_with_rocm(),
     "core is not compiled with CUDA",
 )
 class TestSigmoidBF16(OpTest):
@@ -1765,7 +1767,8 @@ def init_dtype(self):
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
+    not (core.is_compiled_with_cuda() or is_custom_device())
+    or core.is_compiled_with_rocm(),
     "core is not compiled with CUDA",
 )
 class TestSqrtBF16(OpTest):
@@ -2037,7 +2040,7 @@ def setUp(self):
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
         self.outputs = {'Out': out}
         self.convert_input_output()
-        if not core.is_compiled_with_cuda():
+        if not (core.is_compiled_with_cuda() or is_custom_device()):
             self.__class__.no_need_check_grad = True
 
     def init_shape(self):
@@ -2091,7 +2094,7 @@ def setUp(self):
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
         self.outputs = {'Out': out}
         self.convert_input_output()
-        if not core.is_compiled_with_cuda():
+        if not (core.is_compiled_with_cuda() or is_custom_device()):
             self.__class__.no_need_check_grad = True
 
     def init_shape(self):
@@ -4563,7 +4566,8 @@ def init_shape(self):
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
+    not (core.is_compiled_with_cuda() or is_custom_device())
+    or core.is_compiled_with_rocm(),
     "core is not compiled with CUDA",
 )
 class TestSquareBF16(OpTest):
@@ -4917,7 +4921,8 @@ def init_shape(self):
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
+    not (core.is_compiled_with_cuda() or is_custom_device())
+    or core.is_compiled_with_rocm(),
     "core is not compiled with CUDA",
 )
 class TestSoftplusBF16(OpTest):
@@ -5595,7 +5600,8 @@ def test_errors(self):
 # ------------------ Test Cudnn Activation----------------------
 def create_test_act_cudnn_class(parent, atol=1e-3, grad_atol=1e-3):
     @unittest.skipIf(
-        not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+        not (core.is_compiled_with_cuda() or is_custom_device()),
+        "core is not compiled with CUDA",
     )
     class TestActCudnn(parent):
         def init_kernel_type(self):

test/legacy_test/test_adadelta_op.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from op_test import OpTest, get_device_place, get_places
+from op_test import OpTest, get_device_place, get_devices
 
 import paddle
 from paddle import base
@@ -294,7 +294,7 @@ def _test_adadelta_op_dygraph_place_amp(self, place, use_amp=False):
         paddle.enable_static()
 
     def test_main(self):
-        for place in get_places(string_format=True):
+        for place in get_devices():
             use_amp_list = [True, False]
             for use_amp in use_amp_list:
                 self._test_adadelta_op_dygraph_place_amp(place, use_amp)

test/legacy_test/test_adagrad_op.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 
 import numpy as np
 from op import Operator
-from op_test import OpTest, get_device_place, get_places
+from op_test import OpTest, get_device_place, get_devices, get_places
 
 import paddle
 from paddle.base import core
@@ -242,7 +242,7 @@ def _test_adagrad_op_dygraph_place_amp(self, place, use_amp=False):
         paddle.enable_static()
 
     def test_main(self):
-        for place in get_places(string_format=True):
+        for place in get_devices():
             use_amp_list = [True, False]
             for use_amp in use_amp_list:
                 self._test_adagrad_op_dygraph_place_amp(place, use_amp)

test/legacy_test/test_adam_op.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 
 import numpy as np
 from op import Operator
-from op_test import OpTest, get_places
+from op_test import OpTest, get_devices, get_places
 
 import paddle
 from paddle import base
@@ -1296,7 +1296,7 @@ def _adam_optimize_static(
         return out
 
     def _get_places(self):
-        return get_places(string_format=True)
+        return get_devices()
 
     def _check_with_place_amp(self, place, use_amp):
         # test dygraph mode

test/legacy_test/test_adamax_op.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from op_test import OpTest, get_device_place, get_places
+from op_test import OpTest, get_device_place, get_devices
 
 import paddle
 
@@ -275,7 +275,7 @@ def _test_adamax_op_dygraph_place_amp(self, place, use_amp=False):
         paddle.enable_static()
 
     def _get_places(self):
-        return get_places(string_format=True)
+        return get_devices()
 
     def test_main(self):
         for place in self._get_places():

test/legacy_test/test_adamw_op.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 from functools import partial
 
 import numpy as np
-from op_test import OpTest, get_places
+from op_test import OpTest, get_devices
 
 import paddle
 from paddle import base, nn
@@ -758,7 +758,7 @@ def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False):
             optimizer.clear_grad()
 
     def _get_places(self):
-        places = get_places(string_format=True)
+        places = get_devices()
         if paddle.is_compiled_with_xpu():
             places.append('xpu')
         return places

test/legacy_test/test_adaptive_log_softmax_with_loss.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from op_test import get_places
+from op_test import get_devices, get_places
 
 import paddle
 import paddle.optimizer as optim
@@ -58,7 +58,7 @@ def predict(self, input):
 class TestNNAdaptiveLogSoftmaxWithLossAPI(unittest.TestCase):
     def setUp(self):
         paddle.seed(2024)
-        self.place = get_places(string_format=True)
+        self.place = get_devices()
         self.log_np = np.random.randn(4, 8).astype('float32')
         self.predict_np = np.abs(np.random.randn(64, 8).astype('float32'))
 

test/legacy_test/test_blha_get_max_len_op.py

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,7 @@
 import unittest
 
 import numpy as np
+from op_test import is_custom_device
 
 import paddle
 from paddle.base import core
@@ -109,7 +110,8 @@ def test_static_api(self):
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() and not core.is_compiled_with_xpu(),
+    not (core.is_compiled_with_cuda() or is_custom_device())
+    and not core.is_compiled_with_xpu(),
     "Only support XPU or GPU in CUDA mode.",
 )
 class TestBlhaGetMaxLenOp_ZeroSize(unittest.TestCase):

test/legacy_test/test_cartesian_prod.py

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@
 from itertools import product
 
 import numpy as np
-from op_test import get_places
+from op_test import get_devices
 
 import paddle
 from paddle.base import core
@@ -36,7 +36,7 @@ def setUp(self):
         self.c_np = np.random.random(self.c_shape).astype(self.dtype_np)
         self.d_np = np.empty(0, self.dtype_np)
 
-        self.place = get_places(string_format=True)
+        self.place = get_devices()
 
     def init_setting(self):
         self.dtype_np = 'float32'
@@ -119,7 +119,7 @@ def setUp(self):
         self.a_np = np.random.random(self.a_shape).astype(self.dtype_np)
         self.b_np = np.empty(0, self.dtype_np)
 
-        self.place = get_places(string_format=True)
+        self.place = get_devices()
 
     def init_setting(self):
         self.dtype_np = 'float32'
