Commit fad5ec1

[Feature Enhancement] test_compiler supports DCU and adds a GPU utilization check before testing. (#322)

* Add a check of GPU utilization before testing.
* Use paddle.nn.initializer.TruncatedNormal to initialize tensors.
* test_compiler supports DCU.
* Add the missing return.
1 parent acbc3e3 commit fad5ec1

3 files changed, +174 -41 lines changed

graph_net/paddle/test_compiler.py

Lines changed: 32 additions & 25 deletions
@@ -25,7 +25,7 @@ def set_seed(random_seed):
 
 
 def get_hardward_name(args):
-    if args.device == "cuda":
+    if test_compiler_util.is_gpu_device(args.device):
         hardware = paddle.device.cuda.get_device_name(0)
     elif args.device == "cpu":
         hardware = platform.processor()
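
The device check now goes through test_compiler_util.is_gpu_device (added in graph_net/test_compiler_util.py later in this diff), which treats both CUDA and DCU device strings as GPUs. A minimal sketch of the expected behavior, assuming device strings follow the "cuda"/"dcu"/"cpu" convention used by the test driver:

def is_gpu_device(device):
    # Same substring test as the helper added below: "cuda", "cuda:0" and "dcu" all count as GPU devices.
    return "cuda" in device or "dcu" in device

assert is_gpu_device("cuda")
assert is_gpu_device("dcu")
assert not is_gpu_device("cpu")
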
@@ -64,15 +64,15 @@ def get_synchronizer_func(args):
     return paddle.device.synchronize
 
 
-def get_model(args):
+def get_model(model_path):
     model_class = load_class_from_file(
-        f"{args.model_path}/model.py", class_name="GraphModule"
+        f"{model_path}/model.py", class_name="GraphModule"
     )
     return model_class()
 
 
-def get_input_dict(args):
-    inputs_params = utils.load_converted_from_text(f"{args.model_path}")
+def get_input_dict(model_path):
+    inputs_params = utils.load_converted_from_text(f"{model_path}")
     params = inputs_params["weight_info"]
     inputs = inputs_params["input_info"]
 
@@ -81,8 +81,8 @@ def get_input_dict(args):
     return state_dict
 
 
-def get_input_spec(args):
-    inputs_params_list = utils.load_converted_list_from_text(f"{args.model_path}")
+def get_input_spec(model_path):
+    inputs_params_list = utils.load_converted_list_from_text(f"{model_path}")
     input_spec = [None] * len(inputs_params_list)
     for i, v in enumerate(inputs_params_list):
         dtype = v["info"]["dtype"]
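
With this change, get_model, get_input_dict, and get_input_spec take a bare model_path string instead of the whole args namespace, so they can be called per sample from test_multi_models. A hedged usage sketch (the sample path is illustrative, and the helpers are assumed to be importable from graph_net/paddle/test_compiler.py):

from graph_net.paddle import test_compiler

model_path = "samples/paddle/resnet18"  # hypothetical converted-sample directory

model = test_compiler.get_model(model_path)            # loads <model_path>/model.py:GraphModule
input_dict = test_compiler.get_input_dict(model_path)  # replayed weights and inputs
input_spec = test_compiler.get_input_spec(model_path)  # InputSpec list for paddle.jit.to_static
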
@@ -94,7 +94,7 @@ def get_input_spec(args):
 def get_compiled_model(args, model):
     if args.compiler == "nope":
         return model
-    input_spec = get_input_spec(args)
+    input_spec = get_input_spec(args.model_path)
     build_strategy = paddle.static.BuildStrategy()
     compiled_model = paddle.jit.to_static(
         model,
@@ -110,7 +110,7 @@ def get_compiled_model(args, model):
 def get_static_model(args, model):
     static_model = paddle.jit.to_static(
         model,
-        input_spec=get_input_spec(args),
+        input_spec=get_input_spec(args.model_path),
         full_graph=True,
         backend=None,
     )
@@ -138,7 +138,7 @@ def measure_performance(model_call, args, synchronizer_func, profile=False):
         flush=True,
     )
 
-    if "cuda" in args.device:
+    if test_compiler_util.is_gpu_device(args.device):
         """
         Acknowledgement: We evaluate the performance on both end-to-end and GPU-only timings,
         With reference to methods only based on CUDA events from KernelBench in https://github.com/ScalingIntelligence/KernelBench
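
measure_performance times the model both end-to-end and, on GPU devices, with device-side timings. A minimal, self-contained sketch of the end-to-end part under the same synchronize-then-time pattern (this is an illustration with hypothetical names and defaults, not the repository's implementation):

import time
import paddle

def time_end_to_end(model_call, warmup=3, repeats=10):
    # Warm up so one-off compilation and allocator effects do not skew the timing.
    for _ in range(warmup):
        model_call()
    paddle.device.synchronize()

    elapsed_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        model_call()
        paddle.device.synchronize()  # wait for queued device work before stopping the clock
        elapsed_ms.append((time.perf_counter() - start) * 1000)
    return elapsed_ms
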
@@ -249,8 +249,8 @@ def transfer_to_float(origin_outputs):
 
 def test_single_model(args):
     synchronizer_func = get_synchronizer_func(args)
-    input_dict = get_input_dict(args)
-    model = get_model(args)
+    input_dict = get_input_dict(args.model_path)
+    model = get_model(args.model_path)
     model.eval()
 
     test_compiler_util.print_basic_config(
@@ -259,6 +259,7 @@ def test_single_model(args):
 
     # Run on eager mode
     eager_success = False
+    eager_time_stats = {}
     try:
         print("Run model in eager mode.", file=sys.stderr, flush=True)
         static_model = get_static_model(args, model)
@@ -275,6 +276,7 @@ def test_single_model(args):
 
     # Run on compiling mode
     compiled_success = False
+    compiled_time_stats = {}
     try:
         print("Run model in compiled mode.", file=sys.stderr, flush=True)
         compiled_model = get_compiled_model(args, model)
@@ -293,9 +295,9 @@ def test_single_model(args):
     if eager_success and compiled_success:
         check_outputs(args, expected_out, compiled_out)
 
-        test_compiler_util.print_times_and_speedup(
-            args, eager_time_stats, compiled_time_stats
-        )
+    test_compiler_util.print_times_and_speedup(
+        args, eager_time_stats, compiled_time_stats
+    )
 
 
 def get_cmp_equal(expected_out, compiled_out):
@@ -366,20 +368,12 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
 
 
 def test_multi_models(args):
-    test_samples = None
-    if args.allow_list is not None:
-        assert os.path.isfile(args.allow_list)
-        graphnet_root = path_utils.get_graphnet_root()
-        print(f"graphnet_root: {graphnet_root}", file=sys.stderr, flush=True)
-        verified_samples = []
-        with open(args.verified_samples_list_path, "r") as f:
-            for line in f.readlines():
-                test_samples.append(os.path.join(graphnet_root, line.strip()))
+    test_samples = test_compiler_util.get_allow_samples(args.allow_list)
 
     sample_idx = 0
     failed_samples = []
     for model_path in path_utils.get_recursively_model_path(args.model_path):
-        if verified_samples is None or os.path.abspath(model_path) in verified_samples:
+        if test_samples is None or os.path.abspath(model_path) in test_samples:
             print(
                 f"[{sample_idx}] test_compiler, model_path: {model_path}",
                 file=sys.stderr,
@@ -415,11 +409,24 @@ def test_multi_models(args):
 def main(args):
     assert os.path.isdir(args.model_path)
     assert args.compiler in {"cinn", "nope"}
+    assert args.device in ["cuda", "dcu", "cpu"]
 
     initalize_seed = 123
     set_seed(random_seed=initalize_seed)
 
     if path_utils.is_single_model_dir(args.model_path):
+        if paddle.device.is_compiled_with_cuda():
+            device_id = int(paddle.device.get_device().split(":")[-1])
+            device_count = paddle.device.cuda.device_count()
+            gpu_util, mem_util = test_compiler_util.get_device_utilization(
+                device_id, device_count, get_synchronizer_func(args)
+            )
+            if gpu_util is not None and mem_util is not None:
+                print(
+                    f"Device status: gpu_id {device_id}, gpu_util {gpu_util:.2f}%, mem_util {mem_util:.2f}%",
+                    file=sys.stderr,
+                    flush=True,
+                )
        test_single_model(args)
     else:
         test_multi_models(args)
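
The pre-test check derives the logical device index from Paddle's current device string and passes it, together with the visible device count and a synchronizer, to get_device_utilization. paddle.device.get_device() typically returns strings like "gpu:0" on a GPU build and "cpu" otherwise, so splitting on ":" and taking the last field yields the index. A small hedged sketch of that parsing in isolation (the fallback branch is illustrative, not from the repository):

import paddle

device_str = paddle.device.get_device()  # e.g. "gpu:0" on a CUDA/DCU build, "cpu" otherwise
if ":" in device_str:
    device_id = int(device_str.split(":")[-1])
else:
    device_id = 0  # defensive fallback; the commit only runs the check on GPU builds
print(device_str, device_id)
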

graph_net/paddle/utils.py

Lines changed: 3 additions & 2 deletions
@@ -214,9 +214,10 @@ def replay_tensor(info):
     else:
         if mean is not None and std is not None:
             tensor = paddle.empty(shape=shape, dtype=dtype)
-            paddle.nn.init.trunc_normal_(
-                tensor=tensor, mean=mean, std=std, a=min_val, b=max_val
+            initializer = paddle.nn.initializer.TruncatedNormal(
+                mean=mean, std=std, a=min_val, b=max_val
             )
+            initializer(tensor)
             return tensor.to(dtype).to(device)
         else:
             return (
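
The replay path switches to Paddle's public initializer API: a paddle.nn.initializer.TruncatedNormal object is constructed with the recorded mean, std, and truncation bounds, then applied to the freshly allocated tensor. A minimal sketch of the same pattern in isolation (shape, dtype, and bounds here are illustrative; replay_tensor takes them from the serialized tensor metadata):

import paddle

shape, dtype = [4, 8], "float32"
mean, std, min_val, max_val = 0.0, 0.02, -0.04, 0.04

tensor = paddle.empty(shape=shape, dtype=dtype)
initializer = paddle.nn.initializer.TruncatedNormal(mean=mean, std=std, a=min_val, b=max_val)
initializer(tensor)  # fills the tensor in place with values truncated to [a, b]
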

graph_net/test_compiler_util.py

Lines changed: 139 additions & 14 deletions
@@ -3,10 +3,14 @@
 import sys
 import json
 import time
+import subprocess
+import shutil
 import numpy as np
 from dataclasses import dataclass
 from contextlib import contextmanager
 
+from graph_net import path_utils
+
 
 @dataclass
 class DurationBox:
@@ -23,6 +27,103 @@ def naive_timer(duration_box, synchronizer_func):
     duration_box.value = (end - start) * 1000  # Store in milliseconds
 
 
+def is_gpu_device(device):
+    return "cuda" in device or "dcu" in device
+
+
+def get_device_utilization(device_id, device_count, synchronizer_func):
+    current_pid = os.getpid()
+
+    if shutil.which("nvidia-smi"):
+        try:
+            cuda_devices_str = os.getenv("CUDA_VISIBLE_DEVICES", "")
+            if cuda_devices_str != "":
+                cuda_devices = list(map(int, cuda_devices_str.split(",")))
+            else:
+                cuda_devices = list(range(device_count))
+            selected_gpu_id = cuda_devices[device_id]
+
+            print(
+                f"Check the status of GPU {selected_gpu_id} for 5 times.",
+                file=sys.stderr,
+                flush=True,
+            )
+            selected_gpu_uuid, max_gpu_util, max_mem_util = None, 0.0, 0.0
+            for i in range(5):
+                synchronizer_func()
+                time.sleep(1)
+
+                output = (
+                    subprocess.check_output(
+                        [
+                            "nvidia-smi",
+                            f"--query-gpu=index,gpu_uuid,utilization.gpu,memory.used,memory.total",
+                            "--format=csv,noheader,nounits",
+                        ]
+                    )
+                    .decode()
+                    .strip()
+                )
+                for line in output.split("\n"):
+                    if line.strip():
+                        (
+                            gpu_id,
+                            selected_gpu_uuid,
+                            gpu_util,
+                            used_mem,
+                            mem_total,
+                        ) = line.split(", ")
+                        if int(gpu_id) == selected_gpu_id:
+                            break
+
+                gpu_util = float(gpu_util)
+                mem_util = float(used_mem) * 100 / float(mem_total)
+                print(
+                    f"- gpu_id: {selected_gpu_id}, gpu_uuid: {selected_gpu_uuid}, gpu_util: {gpu_util:.2f}%, used_mem: {used_mem}, mem_total: {mem_total}",
+                    file=sys.stderr,
+                    flush=True,
+                )
+
+                max_gpu_util = gpu_util if gpu_util > max_gpu_util else max_gpu_util
+                max_mem_util = mem_util if mem_util > max_mem_util else max_mem_util
+
+            other_tasks = []
+            output = (
+                subprocess.check_output(
+                    [
+                        "nvidia-smi",
+                        f"--query-compute-apps=gpu_uuid,pid,used_memory",
+                        "--format=csv,noheader,nounits",
+                    ]
+                )
+                .decode()
+                .strip()
+            )
+            for line in output.split("\n"):
+                if line.strip():
+                    gpu_uuid, pid, used_memory = line.split(", ")
+                    if gpu_uuid == selected_gpu_uuid and int(pid) != current_pid:
+                        other_tasks.append(line)
+            # Note: in docker container, the current_pid maybe different from that captured by nvidia-smi.
+            print(
+                f"Note: There are {len(other_tasks)} tasks running on GPU {selected_gpu_id} (current_pid:{current_pid}).",
+                file=sys.stderr,
+                flush=True,
+            )
+            for task in other_tasks:
+                gpu_uuid, pid, used_memory = task.split(", ")
+                print(
+                    f"- gpu_uuid:{gpu_uuid}, pid:{pid}, used_memory:{used_memory}",
+                    file=sys.stderr,
+                    flush=True,
+                )
+            return max_gpu_util, max_mem_util
+        except subprocess.CalledProcessError:
+            pass
+
+    return None, None
+
+
 def get_timing_stats(elapsed_times):
     stats = {
         "mean": float(f"{np.mean(elapsed_times):.6g}"),
@@ -75,24 +176,33 @@ def print_basic_config(args, hardware_name, compile_framework_version):
     )
 
 
-def print_running_status(args, eager_success, compiled_success):
+def print_running_status(args, eager_success, compiled_success=None):
     def convert_to_str(b):
         return "success" if b else "failed"
 
-    print_with_log_prompt(
-        "[Result][status]",
-        f"eager:{convert_to_str(eager_success)} compiled:{convert_to_str(compiled_success)}",
-        args.log_prompt,
-    )
+    if compiled_success is not None:
+        print_with_log_prompt(
+            "[Result][status]",
+            f"eager:{convert_to_str(eager_success)} compiled:{convert_to_str(compiled_success)}",
+            args.log_prompt,
+        )
+    else:
+        print_with_log_prompt(
+            "[Result][status]",
+            f"eager:{convert_to_str(eager_success)}",
+            args.log_prompt,
+        )
 
 
 def print_times_and_speedup(args, eager_stats, compiled_stats):
-    print_with_log_prompt(
-        "[Performance][eager]:", json.dumps(eager_stats), args.log_prompt
-    )
-    print_with_log_prompt(
-        "[Performance][compiled]:", json.dumps(compiled_stats), args.log_prompt
-    )
+    if not eager_stats:
+        print_with_log_prompt(
+            "[Performance][eager]:", json.dumps(eager_stats), args.log_prompt
+        )
+    if not compiled_stats:
+        print_with_log_prompt(
+            "[Performance][compiled]:", json.dumps(compiled_stats), args.log_prompt
+        )
 
     e2e_speedup = 0
     gpu_speedup = 0
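
print_running_status now takes compiled_success as an optional argument, so callers that only exercised eager mode can still emit a status line. A small sketch of the two call shapes (args here stands for any namespace carrying whatever print_with_log_prompt reads; a log_prompt attribute is assumed):

from types import SimpleNamespace
from graph_net import test_compiler_util

args = SimpleNamespace(log_prompt="graph-net-test")  # illustrative value

# Both results known (previous behaviour):
test_compiler_util.print_running_status(args, eager_success=True, compiled_success=False)
# Eager-only run: compiled_success defaults to None and is omitted from the status line.
test_compiler_util.print_running_status(args, eager_success=True)
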
@@ -103,7 +213,7 @@ def print_times_and_speedup(args, eager_stats, compiled_stats):
     if eager_e2e_time_ms > 0 and compiled_e2e_time_ms > 0:
         e2e_speedup = eager_e2e_time_ms / compiled_e2e_time_ms
 
-    if "cuda" in args.device:
+    if is_gpu_device(args.device):
         eager_gpu_time_ms = eager_stats.get("gpu", {}).get("mean", 0)
         compiled_gpu_time_ms = compiled_stats.get("gpu", {}).get("mean", 0)
 
@@ -113,7 +223,7 @@ def print_times_and_speedup(args, eager_stats, compiled_stats):
     if e2e_speedup > 0:
         print_with_log_prompt("[Speedup][e2e]:", f"{e2e_speedup:.5f}", args.log_prompt)
 
-    if "cuda" in args.device and gpu_speedup > 0:
+    if is_gpu_device(args.device) and gpu_speedup > 0:
         print_with_log_prompt("[Speedup][gpu]:", f"{gpu_speedup:.5f}", args.log_prompt)
 
 
@@ -224,3 +334,18 @@ def check_allclose(
         compiled_out=compiled_out,
         **kwargs,
     )
+
+
+def get_allow_samples(allow_list):
+    if allow_list is None:
+        return None
+
+    assert os.path.isfile(allow_list), f"{allow_list} is not a regular file."
+    graphnet_root = path_utils.get_graphnet_root()
+    print(f"graphnet_root: {graphnet_root}", file=sys.stderr, flush=True)
+    test_samples = []
+    with open(allow_list, "r") as f:
+        for line in f.readlines():
+            test_samples.append(os.path.join(graphnet_root, line.strip()))
+
+    return test_samples
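
get_allow_samples centralizes the allow-list handling that test_multi_models previously did inline: each line of the file is read as a path relative to the GraphNet root and joined into an absolute sample path, and a None allow list means no filtering. A hedged usage sketch with a hypothetical allow-list file:

from graph_net import test_compiler_util

# allow_list.txt is hypothetical: one repository-relative sample directory per line,
# e.g. "samples/paddle/resnet18".
samples = test_compiler_util.get_allow_samples("allow_list.txt")
# -> e.g. ["/path/to/GraphNet/samples/paddle/resnet18", ...]
# get_allow_samples(None) returns None, which test_multi_models treats as "run every model_path".
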
