Skip to content

Commit 3926906

Browse files
authored
benchmark_inference: set requires_grad=False on params (#2696)
1 parent 83c3229 commit 3926906

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

thunder/benchmarks/benchmark_inference.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -245,10 +245,6 @@ def __init__(self, config: InferenceBenchmarkConfig):
245245
if mesh:
246246
model = parallelize_module(model, mesh, tp_plan)
247247

248-
# Required as that doesn't understand inference mode
249-
for p in model.parameters():
250-
p.requires_grad_(False)
251-
252248
# Sanity check
253249
if not self.config.disable_moe_replacement:
254250
assert type(model.model.layers[1].feed_forward.shared_experts.gate_proj.weight) == DTensor
@@ -266,6 +262,13 @@ def __init__(self, config: InferenceBenchmarkConfig):
266262
model.to_empty(device=DEVICE)
267263
assert all(p.device == DEVICE for p in model.parameters())
268264

265+
# Required as thunder doesn't understand inference mode
266+
# And some prims like `prims._grouped_mm` don't have grad rule defined yet.
267+
for p in model.parameters():
268+
p.requires_grad_(False)
269+
270+
assert all(not p.requires_grad for p in model.parameters())
271+
269272
# `thunderfx` seems to hide the access to vocab_size somewhere so
270273
# store it here before any compiler is applied.
271274
self.vocab_size = model.vocab_size

0 commit comments

Comments
 (0)