Skip to content

Commit 4b6272c

Browse files
authored
Symbolic values: nn.Parameter to have static shapes even with symbolic values (#2749)
1 parent 71b2e7d commit 4b6272c

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

thunder/core/proxies.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2009,7 +2009,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
20092009
# See Note [DistributedDataParallel and distparallel_type]
20102010
distparallel_type = getattr(t, "distparallel_type", None)
20112011
_thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None)
2012-
if using_symbolic_values():
2012+
# For parameters, shapes should be static.
2013+
if using_symbolic_values() and not isinstance(t, torch.nn.Parameter):
20132014
shape_attr = ProvenanceRecord(PseudoInst.LOAD_ATTR, inputs=[copy.copy(history), wrap_const("shape").provenance])
20142015
shape = tuple(
20152016
IntegerProxy(

thunder/tests/test_jit_general.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,24 @@ def foo(a, v):
15811581
assert_close(actual, expected)
15821582

15831583

1584+
def test_cache_symbolic_values_nn_parameter_static_shape():
1585+
linear = torch.nn.Linear(2, 2)
1586+
x = torch.randn(2, 2)
1587+
1588+
jlinear = thunder_jit(linear, cache="symbolic values")
1589+
1590+
jlinear(x)
1591+
exec_trc = thunder.last_traces(jlinear)[-1]
1592+
for bsym in exec_trc.bound_symbols:
1593+
if bsym.sym.name == prims.PrimIDs.UNPACK_TRIVIAL and "weight" in bsym.output.name:
1594+
assert bsym.output.shape == (2, 2)
1595+
elif bsym.sym.name == prims.PrimIDs.UNPACK_TRIVIAL and "bias" in bsym.output.name:
1596+
assert bsym.output.shape == (2,)
1597+
elif bsym.sym.name == prims.PrimIDs.UNPACK_TRIVIAL: # Input TensorProxy, this should have symbolic values
1598+
assert isinstance(bsym.output.shape[0], thunder.core.proxies.IntegerProxy)
1599+
assert isinstance(bsym.output.shape[1], thunder.core.proxies.IntegerProxy)
1600+
1601+
15841602
def test_specific_dataclass_returns():
15851603
import transformers
15861604

0 commit comments

Comments
 (0)