Skip to content

Commit 6a6ccc5

Browse files
committed
Update
1 parent e92d660 commit 6a6ccc5

File tree

1 file changed

+21
-32
lines changed

1 file changed

+21
-32
lines changed

graph_net/torch/test_compiler.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,25 @@ def get_timing_stats_cpu(elapsed_times: list[float]):
146146
return stats
147147

148148

149+
def measure_performance(model_call, args, compiler):
150+
if args.device == "cuda":
151+
times = time_execution_with_cuda_event(
152+
model_call,
153+
num_warmup=args.warmup,
154+
num_trials=args.trials,
155+
device=torch.device("cuda:0"),
156+
)
157+
return get_timing_stats(times)
158+
else:
159+
times = time_execution_naive(
160+
model_call,
161+
compiler.synchronize,
162+
num_warmup=args.warmup,
163+
num_trials=args.trials,
164+
)
165+
return get_timing_stats_cpu(times)
166+
167+
149168
def test_single_model(args):
150169
compiler = get_compiler_backend(args)
151170
input_dict = get_input_dict(args)
@@ -192,38 +211,8 @@ def test_single_model(args):
192211
eager_model_call = lambda: model(**input_dict)
193212
compiled_model_call = lambda: compiled_model(**input_dict)
194213

195-
if args.device == "cuda":
196-
eager_times = time_execution_with_cuda_event(
197-
eager_model_call,
198-
num_warmup=args.warmup,
199-
num_trials=args.trials,
200-
device=torch.device("cuda:0"),
201-
)
202-
eager_stats = get_timing_stats(eager_times)
203-
204-
compiled_times = time_execution_with_cuda_event(
205-
compiled_model_call,
206-
num_warmup=args.warmup,
207-
num_trials=args.trials,
208-
device=torch.device("cuda:0"),
209-
)
210-
compiled_stats = get_timing_stats(compiled_times)
211-
else:
212-
eager_times = time_execution_naive(
213-
eager_model_call,
214-
compiler.synchronize,
215-
num_warmup=args.warmup,
216-
num_trials=args.trials,
217-
)
218-
eager_stats = get_timing_stats_cpu(eager_times)
219-
220-
compiled_times = time_execution_naive(
221-
compiled_model_call,
222-
compiler.synchronize,
223-
num_warmup=args.warmup,
224-
num_trials=args.trials,
225-
)
226-
compiled_stats = get_timing_stats_cpu(compiled_times)
214+
eager_stats = measure_performance(eager_model_call, args, compiler)
215+
compiled_stats = measure_performance(compiled_model_call, args, compiler)
227216

228217
expected_out = eager_model_call()
229218
compiled_out = compiled_model_call()

0 commit comments

Comments
 (0)