Skip to content

Commit b377227

Browse files
anijain2305facebook-github-bot
authored andcommitted
Set model name early to keep warmup and main model same (#159231)
Summary: X-link: pytorch/pytorch#159231 Approved by: https://github.com/williamwen42 ghstack dependencies: #159209 Reviewed By: ZainRizvi Differential Revision: D79112571 fbshipit-source-id: fbaf66b7e28f639684fe4e59f6c54b8a69736b6a
1 parent bed0bfe commit b377227

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

userbenchmark/dynamo/dynamobench/common.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,6 +2425,8 @@ def warmup(fn, model, example_inputs, mode, niters=10):
24252425
# Use distributed wrapping as necessary
24262426
model = self.deepcopy_and_maybe_parallelize(model)
24272427

2428+
if not hasattr(model, name):
2429+
model.name = name
24282430
self.init_optimizer(name, current_device, model.parameters())
24292431

24302432
# The self.autocast context is needed for the model we export with aot_compile,
@@ -2528,8 +2530,6 @@ def warmup(fn, model, example_inputs, mode, niters=10):
25282530
result_summary = latency_experiment_summary(
25292531
self.suite_name, self.args, model, timings, **experiment_kwargs
25302532
)
2531-
if not hasattr(model, name):
2532-
model.name = name
25332533
results.append(result_summary)
25342534
return " ".join(map(str, results))
25352535

@@ -2586,6 +2586,9 @@ def warmup(fn, model, example_inputs, mode, niters=5):
25862586
# Use distributed wrapping as necessary
25872587
model = self.deepcopy_and_maybe_parallelize(model)
25882588

2589+
if not hasattr(model, name):
2590+
model.name = name
2591+
25892592
self.init_optimizer(name, current_device, model.parameters())
25902593

25912594
# The self.autocast context is needed for the model we export with aot_compile,
@@ -2699,8 +2702,6 @@ def warmup(fn, model, example_inputs, mode, niters=5):
26992702
f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s"
27002703
)
27012704

2702-
if not hasattr(model, name):
2703-
model.name = name
27042705
results.append(experiment(model, example_inputs, **experiment_kwargs))
27052706
return " ".join(map(str, results))
27062707

0 commit comments

Comments
 (0)