Skip to content

Commit 9e12768

Browse files
shino16beverlylytleKaelanDt
authored
Fix KeyError 'i23' with symbolic shapes and reshape bsyms (#2764)
Co-authored-by: Masato Shinokawa <[email protected]> Co-authored-by: beverlylytle <[email protected]> Co-authored-by: KaelanDt <[email protected]>
1 parent ec21d73 commit 9e12768

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

thunder/core/update_aliases.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,21 +90,23 @@ def replace_args_with_alias_map(
9090
reshaped_arg = arg
9191
if arg_to_replace.shape != arg.shape:
9292
with tracectx(computation_trace):
93-
reshaped_arg = prims.reshape.meta(arg, arg_to_replace.shape)
94-
arg_to_optional_bsyms[variableify(arg_to_replace)] = prims.reshape.bind(
95-
arg,
96-
arg_to_replace.shape,
97-
output=reshaped_arg,
98-
)
93+
shape = prims.shape.meta(arg_to_replace)
94+
reshaped_arg = prims.reshape.meta(arg, shape)
95+
reshape_bsym = prims.reshape.bind(arg, shape, output=reshaped_arg)
96+
if using_symbolic_values():
97+
shape_bsym = prims.shape.bind(arg_to_replace, output=shape)
98+
arg_to_optional_bsyms[variableify(arg_to_replace)] = (shape_bsym, reshape_bsym)
99+
else:
100+
arg_to_optional_bsyms[variableify(arg_to_replace)] = (reshape_bsym,)
99101
swap_map_for_aliases[variableify(arg_to_replace)] = reshaped_arg
100102
appended_bsyms = {}
101103
for bsym in computation_trace.bound_symbols:
102104
for arg in filter(lambda p: isinstance(p, TensorProxy), bsym.flat_args):
103-
reshape_bsym = arg_to_optional_bsyms.get(variableify(arg))
104-
if reshape_bsym is not None:
105-
if reshape_bsym not in appended_bsyms:
106-
bsyms.append(reshape_bsym)
107-
appended_bsyms[reshape_bsym] = arg
105+
reshape_bsyms = arg_to_optional_bsyms.get(variableify(arg))
106+
if reshape_bsyms is not None:
107+
if reshape_bsyms not in appended_bsyms:
108+
bsyms.extend(reshape_bsyms)
109+
appended_bsyms[reshape_bsyms] = arg
108110
if replaced_args_map := {
109111
x.name: swap_map_for_aliases[variableify(x)].name
110112
for x in filter(lambda p: isinstance(p, TensorProxy), bsym.flat_args)

thunder/tests/test_update_aliases.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,9 @@ def g(a, b):
328328

329329
@instantiate(
330330
dtypes=NOTHING,
331+
decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
331332
)
332-
def test_aliased_input(executor, device, dtype):
333+
def test_aliased_input(executor, device, dtype, cache):
333334
def f(x, y, z):
334335
return y.exp_().add(x) + z.exp()
335336

@@ -339,7 +340,7 @@ def f(x, y, z):
339340
a_ = a.clone().detach()
340341
b_ = b.clone().detach()
341342
c_ = c.clone().detach()
342-
jfn = executor.make_callable(f)
343+
jfn = executor.make_callable(f, cache=cache)
343344
actual = jfn(a, b, c)
344345
expected = f(a_, b_, c_)
345346
torch.testing.assert_close(actual, expected)
@@ -350,8 +351,9 @@ def f(x, y, z):
350351

351352
@instantiate(
352353
dtypes=NOTHING,
354+
decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
353355
)
354-
def test_write_to_intermediate_result(executor, device, dtype):
356+
def test_write_to_intermediate_result(executor, device, dtype, cache):
355357
if executor == nvFuserExecutor:
356358
pytest.xfail("nvFuser does not support writing to intermediate results")
357359

@@ -361,7 +363,7 @@ def fn(x):
361363
return y
362364

363365
a = make_tensor((2, 3), dtype=torch.float32, device=device)
364-
jfn = executor.make_callable(fn, skip_inplace_alias_updates=True)
366+
jfn = executor.make_callable(fn, cache=cache)
365367
actual = jfn(a)
366368
expected = fn(a)
367369
torch.testing.assert_close(actual, expected)
@@ -521,8 +523,9 @@ def foo(x):
521523

522524
@instantiate(
523525
dtypes=(dtypes.float32,),
526+
decorators=(pytest.mark.parametrize("cache", ("constant values", "symbolic values")),),
524527
)
525-
def test_aliasing_for_viewed_input_of_different_shapes(executor, device, dtype):
528+
def test_aliasing_for_viewed_input_of_different_shapes(executor, device, dtype, cache):
526529
def f(x, y, z):
527530
return x + 2, y.add_(z)
528531

@@ -532,7 +535,7 @@ def f(x, y, z):
532535
a_ = a.clone().detach()
533536
b_ = a_[0, :]
534537
c_ = a_[1, :]
535-
jfn = executor.make_callable(f)
538+
jfn = executor.make_callable(f, cache=cache)
536539
actual = jfn(a, b, c)
537540
expected = f(a_, b_, c_)
538541
torch.testing.assert_close(actual, expected)

0 commit comments

Comments
 (0)