
Commit 9eff1c0

Simplify torch test_compiler.

1 parent 7e2e6cb commit 9eff1c0

File tree

3 files changed: +73 −183 lines changed

graph_net/paddle/test_compiler.py
Lines changed: 1 addition & 3 deletions

@@ -35,10 +35,8 @@ def get_hardward_name(args):
 
 
 def get_compile_framework_version(args):
-    if args.compiler == "cinn":
+    if args.compiler in ["cinn", "nope"]:
         return paddle.__version__
-    if args.compiler == "nope":
-        return "nope-baseline"
     return "unknown"
 
 
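With this change the "nope" baseline reports the installed Paddle version instead of the literal string "nope-baseline", matching the torch side below. A minimal sketch of the new behavior, assuming an argparse-style namespace:

    import argparse
    import paddle

    args = argparse.Namespace(compiler="nope")
    # Both "cinn" and "nope" now report the Paddle version:
    assert get_compile_framework_version(args) == paddle.__version__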

graph_net/test_compiler_util.py
Lines changed: 2 additions & 0 deletions

@@ -134,6 +134,8 @@ def check_output_datatype(args, eager_dtypes, compiled_dtypes):
     )
 
     # datatype check
+    # "datatype not match" is recognized as a large loss in analysis process later,
+    # and is not recognized as a failure here.
     type_match = check_type_match(eager_dtypes, compiled_dtypes)
     print_with_log_prompt(
         "[DataType]",

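The added comment records a deliberate policy: a dtype mismatch between eager and compiled outputs is logged (and later treated as a large loss by the analysis stage) rather than failing the run. A usage sketch, assuming an args namespace as used throughout and dtype strings formatted the way the torch driver below produces them:

    eager_dtypes = ["float32", "int64"]
    compiled_dtypes = ["float16", "int64"]  # e.g. the compiler downcast one output
    # Logs a "[DataType]" line and returns the match flag; the run is not failed.
    type_match = check_output_datatype(args, eager_dtypes, compiled_dtypes)
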
graph_net/torch/test_compiler.py
Lines changed: 70 additions & 180 deletions
@@ -43,6 +43,24 @@ def set_seed(random_seed):
     torch.cuda.manual_seed_all(random_seed)
 
 
+def get_hardward_name(args):
+    hardware_name = "unknown"
+    if "cuda" in args.device:
+        hardware_name = torch.cuda.get_device_name(args.device)
+    elif args.device == "cpu":
+        hardware_name = platform.processor()
+    return hardware_name
+
+
+def get_compile_framework_version(args):
+    if args.compiler in ["inductor", "nope"]:
+        return torch.__version__
+    elif args.compiler in ["tvm", "xla", "tensorrt", "bladedisc"]:
+        # Assuming compiler object has a version attribute
+        return f"{args.compiler.capitalize()} {compiler.version}"
+    return "unknown"
+
+
 def load_class_from_file(
     args: argparse.Namespace, class_name: str, device: str
 ) -> Type[torch.nn.Module]:
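
These two helpers absorb logic that previously lived inline in test_single_model (removed further down). A usage sketch, assuming a CUDA build; note that `compiler` in the f-string resolves from the surrounding module scope, not from an argument:

    import argparse

    args = argparse.Namespace(device="cuda:0", compiler="inductor")
    get_hardward_name(args)              # e.g. "NVIDIA A100-SXM4-80GB"
    get_compile_framework_version(args)  # torch.__version__, e.g. "2.4.0"
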
@@ -87,31 +105,6 @@ def get_input_dict(args):
     }
 
 
-@dataclass
-class DurationBox:
-    value: float
-
-
-@contextmanager
-def naive_timer(duration_box, synchronizer_func):
-    synchronizer_func()
-    start = time.time()
-    yield
-    synchronizer_func()
-    end = time.time()
-    duration_box.value = (end - start) * 1000  # Store in milliseconds
-
-
-def get_timing_stats(elapsed_times: List[float]):
-    stats = {
-        "mean": float(f"{np.mean(elapsed_times):.3g}"),
-        "std": float(f"{np.std(elapsed_times):.3g}"),
-        "min": float(f"{np.min(elapsed_times):.3g}"),
-        "max": float(f"{np.max(elapsed_times):.3g}"),
-    }
-    return stats
-
-
 def measure_performance(model_call, args, compiler):
     stats = {}
 
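DurationBox, naive_timer, and get_timing_stats are not deleted outright; they move to graph_net/test_compiler_util.py and are called below through the test_compiler_util module. A minimal sketch of the timing pattern, assuming the moved definitions match the ones removed here:

    import torch
    from graph_net import test_compiler_util

    box = test_compiler_util.DurationBox(-1)
    with test_compiler_util.naive_timer(box, torch.cuda.synchronize):
        ...  # timed work goes here
    print(f"elapsed: {box.value:.3f} ms")  # naive_timer stores milliseconds
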
@@ -120,27 +113,26 @@ def measure_performance(model_call, args, compiler):
         model_call()
     compiler.synchronize()
 
+    hardware_name = get_hardward_name(args)
+    print(
+        f"[Profiling] Using device: {args.device} {hardware_name}, warm up {args.warmup}, trials {args.trials}",
+        file=sys.stderr,
+        flush=True,
+    )
+
     if "cuda" in args.device:
         """
         Acknowledgement: We evaluate the performance on both end-to-end and GPU-only timings,
         With reference to methods only based on CUDA events from KernelBench in https://github.com/ScalingIntelligence/KernelBench
         """
 
-        device = torch.device(args.device)
-        hardware_name = torch.cuda.get_device_name(device)
-        print(
-            f"{args.log_prompt} [Profiling] Using device: {args.device} {hardware_name}, warm up {args.warmup}, trials {args.trials}",
-            file=sys.stderr,
-            flush=True,
-        )
-
         e2e_times = []
         gpu_times = []
 
         for i in range(args.trials):
             # End-to-end timing (naive_timer)
-            duration_box = DurationBox(-1)
-            with naive_timer(duration_box, compiler.synchronize):
+            duration_box = test_compiler_util.DurationBox(-1)
+            with test_compiler_util.naive_timer(duration_box, compiler.synchronize):
                 # GPU-only timing (CUDA Events)
                 start_event = torch.cuda.Event(enable_timing=True)
                 end_event = torch.cuda.Event(enable_timing=True)
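
As the KernelBench acknowledgement describes, each trial captures two timings at once: CUDA events bracket the GPU-only work inside the wall-clock (end-to-end) timer. A standalone sketch of that nesting, assuming a CUDA device:

    import time
    import torch

    torch.cuda.synchronize()
    t0 = time.time()                          # end-to-end clock starts
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    y = torch.mm(torch.randn(1024, 1024, device="cuda"),
                 torch.randn(1024, 1024, device="cuda"))
    end.record()
    torch.cuda.synchronize()                  # flush queued GPU work
    e2e_ms = (time.time() - t0) * 1000        # wall-clock milliseconds
    gpu_ms = start.elapsed_time(end)          # GPU-only milliseconds
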
@@ -149,7 +141,7 @@ def measure_performance(model_call, args, compiler):
                 model_call()
 
                 end_event.record()
-                torch.cuda.synchronize(device=device)
+                compiler.synchronize()
 
             gpu_time_ms = start_event.elapsed_time(end_event)
             e2e_times.append(duration_box.value)
@@ -160,29 +152,22 @@ def measure_performance(model_call, args, compiler):
                 flush=True,
             )
 
-        stats["e2e"] = get_timing_stats(e2e_times)
-        stats["gpu"] = get_timing_stats(gpu_times)
+        stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times)
+        stats["gpu"] = test_compiler_util.get_timing_stats(gpu_times)
 
     else:  # CPU or other devices
-        hardware_name = platform.processor()
-        print(
-            f"[Profiling] Using device: {args.device} {hardware_name}, warm up {args.warmup}, trials {args.trials}",
-            file=sys.stderr,
-            flush=True,
-        )
-
         e2e_times = []
         for i in range(args.trials):
-            duration_box = DurationBox(-1)
-            with naive_timer(duration_box, compiler.synchronize):
+            duration_box = test_compiler_util.DurationBox(-1)
+            with test_compiler_util.naive_timer(duration_box, compiler.synchronize):
                 model_call()
             print(
                 f"Trial {i + 1}: e2e={duration_box.value:.5f} ms",
                 file=sys.stderr,
                 flush=True,
             )
             e2e_times.append(duration_box.value)
-        stats["e2e"] = get_timing_stats(e2e_times)
+        stats["e2e"] = test_compiler_util.get_timing_stats(e2e_times)
 
     return stats
 
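For reference, get_timing_stats (shown above as it is removed) reduces each list of trial timings to a four-field summary, so a CUDA-path result from measure_performance looks roughly like this (values illustrative):

    stats = {
        "e2e": {"mean": 12.3, "std": 0.456, "min": 11.9, "max": 13.1},  # ms
        "gpu": {"mean": 10.1, "std": 0.321, "min": 9.87, "max": 10.8},
    }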

@@ -191,49 +176,9 @@ def test_single_model(args):
     compiler = get_compiler_backend(args)
     input_dict = get_input_dict(args)
     model = get_model(args, args.device)
-    model_path = os.path.normpath(args.model_path)
-    print(f"{args.log_prompt} [Processing] {model_path}", file=sys.stderr, flush=True)
-    model_name = os.path.basename(model_path)
-    print(
-        f"{args.log_prompt} [Config] model: {model_name}", file=sys.stderr, flush=True
-    )
-    print(
-        f"{args.log_prompt} [Config] device: {args.device}", file=sys.stderr, flush=True
-    )
-
-    hardware_name = "unknown"
-    if "cuda" in args.device:
-        hardware_name = torch.cuda.get_device_name(args.device)
-    elif args.device == "cpu":
-        hardware_name = platform.processor()
-    print(
-        f"{args.log_prompt} [Config] hardware: {hardware_name}",
-        file=sys.stderr,
-        flush=True,
-    )
-
-    print(
-        f"{args.log_prompt} [Config] compiler: {args.compiler}",
-        file=sys.stderr,
-        flush=True,
-    )
-    print(
-        f"{args.log_prompt} [Config] warmup: {args.warmup}", file=sys.stderr, flush=True
-    )
-    print(
-        f"{args.log_prompt} [Config] trials: {args.trials}", file=sys.stderr, flush=True
-    )
 
-    version_str = "unknown"
-    if args.compiler == "inductor":
-        version_str = torch.__version__
-    elif args.compiler in ["tvm", "xla", "tensorrt", "bladedisc"]:
-        # Assuming compiler object has a version attribute
-        version_str = f"{args.compiler.capitalize()} {compiler.version}"
-    print(
-        f"{args.log_prompt} [Config] compile_framework_version: {version_str}",
-        file=sys.stderr,
-        flush=True,
+    test_compiler_util.print_basic_config(
+        args, get_hardward_name(args), get_compile_framework_version(args)
     )
 
     runtime_seed = 1024
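
All of the [Processing]/[Config] prints collapse into a single test_compiler_util.print_basic_config call. Presumably it reproduces the same stderr lines from the removed format strings, along the lines of:

    <log_prompt> [Config] model: <model_name>
    <log_prompt> [Config] device: cuda:0
    <log_prompt> [Config] hardware: NVIDIA A100-SXM4-80GB
    <log_prompt> [Config] compiler: inductor
    <log_prompt> [Config] warmup: 10
    <log_prompt> [Config] trials: 100
    <log_prompt> [Config] compile_framework_version: 2.4.0
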
@@ -245,28 +190,11 @@ def test_single_model(args):
     try:
         eager_model_call = lambda: model(**input_dict)
         eager_stats = measure_performance(eager_model_call, args, compiler)
-        print(
-            f"{args.log_prompt} [Performance][eager]: {json.dumps(eager_stats)}",
-            file=sys.stderr,
-            flush=True,
-        )
 
         torch.manual_seed(runtime_seed)
         expected_out = eager_model_call()
         if not isinstance(expected_out, tuple):
             expected_out = (expected_out,)
-
-        eager_types = [
-            str(x.dtype).replace("torch.", "")
-            if isinstance(x, torch.Tensor)
-            else type(x).__name__
-            for x in expected_out
-        ]
-        print(
-            f"{args.log_prompt} [Datatype][eager]: {' '.join(eager_types)}",
-            file=sys.stderr,
-            flush=True,
-        )
     except (TypeError, RuntimeError) as e:
         print(f"Eager model execution failed: {str(e)}", file=sys.stderr)
         eager_failure = True
@@ -286,43 +214,12 @@
         torch.manual_seed(runtime_seed)
         compiled_model_call = lambda: compiled_model(**input_dict)
         compiled_stats = measure_performance(compiled_model_call, args, compiler)
-        print(
-            f"{args.log_prompt} [Performance][compiled]: {json.dumps(compiled_stats)}",
-            file=sys.stderr,
-            flush=True,
-        )
 
         compiled_out = compiled_model_call()
         if not isinstance(compiled_out, tuple):
             compiled_out = (compiled_out,)
         if args.compiler == "xla":
             compiled_out = tuple(item.to("cpu").to("cuda") for item in compiled_out)
-
-        compiled_types = [
-            str(x.dtype).replace("torch.", "")
-            if isinstance(x, torch.Tensor)
-            else type(x).__name__
-            for x in compiled_out
-        ]
-        print(
-            f"{args.log_prompt} [Datatype][compiled]: {' '.join(compiled_types)}",
-            file=sys.stderr,
-            flush=True,
-        )
-
-        # datatype check
-        type_match = all(
-            eager == compiled for eager, compiled in zip(eager_types, compiled_types)
-        )
-        print(
-            f"{args.log_prompt} [DataType] eager:{eager_types} compiled:{compiled_types} match:{type_match}",
-            file=sys.stderr,
-        )
-        # "datatype not match" is recognized as a large loss in analysis process later,
-        # and is not recognized as a failure here.
-
-        compare_correctness(expected_out, compiled_out, args)
-
     except (TypeError, RuntimeError) as e:
         print(f"Compiled model execution failed: {str(e)}", file=sys.stderr)
         compiled_failure = True
@@ -342,39 +239,13 @@ def test_single_model(args):
             flush=True,
         )
     else:
+        compare_correctness(expected_out, compiled_out, args)
+
         print(
             f"{args.log_prompt} [Result] status: success", file=sys.stderr, flush=True
         )
 
-        e2e_speedup = 0
-        gpu_speedup = 0
-
-        eager_e2e_time_ms = eager_stats.get("e2e", {}).get("mean", 0)
-        compiled_e2e_time_ms = compiled_stats.get("e2e", {}).get("mean", 0)
-
-        if eager_e2e_time_ms > 0 and compiled_e2e_time_ms > 0:
-            e2e_speedup = eager_e2e_time_ms / compiled_e2e_time_ms
-
-        if "cuda" in args.device:
-            eager_gpu_time_ms = eager_stats.get("gpu", {}).get("mean", 0)
-            compiled_gpu_time_ms = compiled_stats.get("gpu", {}).get("mean", 0)
-
-            if eager_gpu_time_ms > 0 and compiled_gpu_time_ms > 0:
-                gpu_speedup = eager_gpu_time_ms / compiled_gpu_time_ms
-
-        if e2e_speedup > 0:
-            print(
-                f"{args.log_prompt} [Speedup][e2e]: {e2e_speedup:.4f}",
-                file=sys.stderr,
-                flush=True,
-            )
-
-        if "cuda" in args.device and gpu_speedup > 0:
-            print(
-                f"{args.log_prompt} [Speedup][gpu]: {gpu_speedup:.4f}",
-                file=sys.stderr,
-                flush=True,
-            )
+        test_compiler_util.print_times_and_speedup(args, eager_stats, compiled_stats)
 
 
 def print_and_store_cmp(key, cmp_func, args, expected_out, compiled_out, **kwargs):
@@ -388,22 +259,41 @@ def print_and_store_cmp(key, cmp_func, args, expected_out, compiled_out, **kwargs):
 
 
 def compare_correctness(expected_out, compiled_out, args):
-    test_compiler_util.check_equal(
-        args,
-        expected_out,
-        compiled_out,
-        cmp_equal_func=get_cmp_equal,
-    )
+    eager_dtypes = [
+        str(x.dtype).replace("torch.", "")
+        if isinstance(x, torch.Tensor)
+        else type(x).__name__
+        for x in expected_out
+    ]
+    compiled_dtypes = [
+        str(x.dtype).replace("torch.", "")
+        if isinstance(x, torch.Tensor)
+        else type(x).__name__
+        for x in compiled_out
+    ]
 
-    test_compiler_util.check_allclose(
-        args,
-        expected_out,
-        compiled_out,
-        cmp_all_close_func=get_cmp_all_close,
-        cmp_max_diff_func=get_cmp_max_diff,
-        cmp_mean_diff_func=get_cmp_mean_diff,
+    # datatype check
+    type_match = test_compiler_util.check_output_datatype(
+        args, eager_dtypes, compiled_dtypes
     )
 
+    if type_match:
+        test_compiler_util.check_equal(
+            args,
+            expected_out,
+            compiled_out,
+            cmp_equal_func=get_cmp_equal,
+        )
+
+        test_compiler_util.check_allclose(
+            args,
+            expected_out,
+            compiled_out,
+            cmp_all_close_func=get_cmp_all_close,
+            cmp_max_diff_func=get_cmp_max_diff,
+            cmp_mean_diff_func=get_cmp_mean_diff,
+        )
+
 
 def get_cmp_equal(expected_out, compiled_out):
     return " ".join(
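
The dtype strings are built with the same comprehension the old inline check used, so tensors and plain Python values can be logged and compared uniformly. A small runnable example of that formatting:

    import torch

    out = (torch.zeros(2, dtype=torch.float32), 3)
    dtypes = [
        str(x.dtype).replace("torch.", "")
        if isinstance(x, torch.Tensor)
        else type(x).__name__
        for x in out
    ]
    print(dtypes)  # ['float32', 'int']

Gating check_equal/check_allclose on type_match avoids comparing tensors of different dtypes element-wise; a mismatch is already logged by check_output_datatype and treated as a large loss downstream rather than as a failure.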
