Skip to content

Commit f7ba459

Browse files
committed
test_compiler supports XPU.
1 parent c3c8599 commit f7ba459

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

graph_net/paddle/backend/cinn_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ def synchronize(self):
1919
if (
2020
paddle.device.is_compiled_with_cuda()
2121
or paddle.device.is_compiled_with_rocm()
22+
or paddle.device.is_compiled_with_xpu()
2223
):
23-
paddle.device.synchronize()
24+
paddle.device.synchronize()

graph_net/paddle/backend/nope_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ def synchronize(self):
1010
if (
1111
paddle.device.is_compiled_with_cuda()
1212
or paddle.device.is_compiled_with_rocm()
13+
or paddle.device.is_compiled_with_xpu()
1314
):
14-
paddle.device.synchronize()
15+
paddle.device.synchronize()

graph_net/paddle/test_compiler.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import random
1313
import platform
1414
import traceback
15+
import subprocess
16+
import re
1517

1618
from graph_net.paddle import utils
1719
from graph_net import path_utils
@@ -40,12 +42,21 @@ def set_seed(random_seed):
4042

4143

4244
def get_hardward_name(args):
45+
hardware = "unknown"
4346
if test_compiler_util.is_gpu_device(args.device):
4447
hardware = paddle.device.cuda.get_device_name(0)
48+
elif args.device == "xpu":
49+
try:
50+
output = subprocess.check_output(["xpu-smi", "-L"], text=True)
51+
hardware = next(
52+
match.group(2)
53+
for line in output.splitlines()
54+
if (match := re.match(r"XPU\s+(\d+):\s+(.+?)\s+\(UUID:\s*([^)]+)\)", line))
55+
)
56+
except Exception as e:
57+
pass
4558
elif args.device == "cpu":
4659
hardware = platform.processor()
47-
else:
48-
hardware = "unknown"
4960
return hardware
5061

5162

@@ -422,7 +433,7 @@ def test_multi_models(args):
422433
def main(args):
423434
assert os.path.isdir(args.model_path)
424435
assert args.compiler in {"cinn", "nope"}
425-
assert args.device in ["cuda", "dcu", "cpu"]
436+
assert args.device in ["cuda", "dcu", "xpu", "cpu"]
426437

427438
initalize_seed = 123
428439
set_seed(random_seed=initalize_seed)

graph_net/paddle/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def convert_to_valid_number(data_type, value):
139139

140140

141141
def convert_meta_classes_to_tensors(file_path):
142+
current_device = paddle.device.get_device()
142143
for name, cls in _get_classes(file_path):
143144
attrs = {
144145
k: v
@@ -159,7 +160,7 @@ def convert_meta_classes_to_tensors(file_path):
159160
"info": {
160161
"shape": attrs.get("shape", []),
161162
"dtype": data_type,
162-
"device": attrs.get("device", "gpu"),
163+
"device": attrs.get("device", current_device),
163164
"mean": convert_to_valid_number(data_type, attrs.get("mean", None)),
164165
"std": convert_to_valid_number(data_type, attrs.get("std", None)),
165166
"min_val": convert_to_valid_number(data_type, attrs.get("min_val", 0)),

0 commit comments

Comments
 (0)