Skip to content

Commit 41ade29

Browse files
yushangdiamathewc
authored andcommitted
Fix with effect lowering for list return type (pytorch#149510)
Summary: - For `torch.ops.higher_order.with_effects`'s lowering, we should not extract the items out of an list (i.e. `*result` vs `result`). The `get_attr` nodes consider the result to be in the list format. Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r test_torchbind_aot_compile buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r list_return buck run //caffe2/torch/fb/sparsenn:sigrid_test -- -r test_transform_torch_bind # tested together with D70013257 buck run fbcode//mode/dev-nosan //caffe2/test:test_export -- -r test_custom_obj ``` Reviewed By: angelayi Differential Revision: D71346024 Pull Request resolved: pytorch#149510 Approved by: https://github.com/zou3519
1 parent d7c37c9 commit 41ade29

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

test/inductor/test_torchbind.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,30 @@ def test_torchbind_aot_compile_constant_folding(self):
286286
# TODO: add accuracy test after we support loading and running compiled models with
287287
# torchbind objects.
288288

289+
def test_torchbind_list_return_aot_compile(self):
290+
class M(torch.nn.Module):
291+
def __init__(self) -> None:
292+
super().__init__()
293+
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
294+
295+
def forward(self, x):
296+
a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
297+
y = a[0] + a[1] + a[2]
298+
b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
299+
return x + b
300+
301+
m = M()
302+
inputs = (torch.ones(2, 3),)
303+
304+
# We can't directly torch.compile because dynamo doesn't trace ScriptObjects yet
305+
with enable_torchbind_tracing():
306+
ep = torch.export.export(m, inputs, strict=False)
307+
308+
aot_compile(ep.module(), inputs, options={"aot_inductor.package": True})
309+
310+
# TODO: add accuracy test after we support loading and running compiled models with
311+
# torchbind objects.
312+
289313

290314
if __name__ == "__main__":
291315
run_tests()

torch/_higher_order_ops/effects.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ def with_effects_dense(
146146
) -> tuple[torch.Tensor, ...]:
147147
out = op(*args, **kwargs)
148148
new_token = new_token_tensor()
149+
# [NOTE: with_effects return type]
150+
# Note that we should only do *out for tuple type, but not list type.
151+
# This is to match the schema of the op.
152+
# For tuple output, the length of schema output is the same as the length of out.
153+
# For list output, the length of schema output is 1 (e.g. Tensor[]) regardless of the
154+
# length of the list.
149155
if isinstance(out, tuple):
150156
return (new_token, *out)
151157
return (new_token, out)

torch/_inductor/lowering.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6954,7 +6954,9 @@ def with_effects(token, op, *args, **kwargs):
69546954
return (effectful_kernel,)
69556955

69566956
result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result)
6957-
if not isinstance(result, (list, tuple)):
6957+
# See [NOTE: with_effects return type]
6958+
# Only return `result` if it is a tuple, not list.
6959+
if not isinstance(result, tuple):
69586960
return (effectful_kernel, result)
69596961
else:
69606962
return (effectful_kernel, *result)

0 commit comments

Comments
 (0)