Commit 7778a58
Revert "[export] Handle kwargs better in aot_export_joint_with_descriptors (pytorch#165334)"
This reverts commit bbb902c. Reverted pytorch#165334 on behalf of https://github.com/jeffdaily: trunk CI passed on the PR, but failures appeared on HUD after merge in test/functorch/test_aot_joint_with_descriptors.py::TestAOTJointWithDescriptors::test_module_with_kwargs. [GH job link](https://github.com/pytorch/pytorch/actions/runs/18511729262/job/52755708742), [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/bbb902c8dd911e1587253f496c1e2fb178d4b6a1) ([comment](pytorch#165334 (comment)))
1 parent e7091a4 commit 7778a58
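
The failing test can be rerun in isolation to confirm the revert. A hedged repro sketch, assuming PyTorch's standard `run_tests()` entry point (which accepts unittest-style `-k` filters); the exact invocation is not taken from the CI log:

```python
# Hedged repro sketch for the failure cited in the revert message.
# Assumes the standard PyTorch test entry point (run_tests()) and that
# the command is run from the pytorch repo root.
import subprocess

subprocess.run(
    [
        "python",
        "test/functorch/test_aot_joint_with_descriptors.py",
        "-k",
        "test_module_with_kwargs",
    ],
    check=True,  # raise CalledProcessError if the test fails
)
```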

2 files changed: +16 -33 lines

test/functorch/test_aot_joint_with_descriptors.py

13 additions, 23 deletions

```diff
@@ -315,58 +315,47 @@ def __init__(self):
         super().__init__()
         self.linear = nn.Linear(3, 2)
 
-    def forward(self, x, *, scale):
+    def forward(self, x, scale=1.0):
         return self.linear(x) * scale
 
 model = ModuleWithKwargs()
 inputs = (torch.randn(4, 3),)
-kwargs = {"scale": torch.tensor(2.0)}
-
-gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)
+kwargs = {"scale": 2.0}
 
 with ExitStack() as stack:
     # Export joint with descriptors
     joint_with_descriptors = aot_export_joint_with_descriptors(
-        stack, gm, inputs, kwargs, decompositions=decomposition_table
+        stack, model, inputs, kwargs, decompositions=decomposition_table
     )
 
     # Test the exported graph structure
     graph_code = joint_with_descriptors.graph_module.print_readable(
         print_output=False, expanded_def=True
     )
 
-    # For some reason PYTORCH_TEST_WITH_CROSSREF will add extra spaces.
-    # I tried to fix this in normalize_gm but there are too many files
-    # depending on that behavior..
-    graph_code_str = normalize_gm(graph_code)
-    graph_code_str = "\n".join(
-        [line for line in graph_code_str.split("\n") if len(line.rstrip()) > 0]
-    )
-
     # Expect test on the printed graph
     self.assertExpectedInline(
-        graph_code_str,
+        normalize_gm(graph_code),
         """\
 class inner_f(torch.nn.Module):
     def forward(
         self,
         primals,
         tangents,
     ):
-        primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear.weight')
-        primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear.bias')
+        primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight')
+        primals_2: "f32[2]" # ParamAOTInput(target='linear.bias')
         primals_3: "f32[4, 3]" # PlainAOTInput(idx=0)
-        primals_4: "f32[]" # PlainAOTInput(idx=1)
         tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0))
-        primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
+        primals_1, primals_2, primals_3, primals_4 , tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
         transpose: "f32[3, 2]" = torch.ops.prims.transpose.default(primals_1, [1, 0]); primals_1 = None
         mm: "f32[4, 2]" = torch.ops.aten.mm.default(primals_3, transpose); transpose = None
         mul: "f32[4, 2]" = torch.ops.prims.mul.default(mm, 1.0); mm = None
         mul_1: "f32[2]" = torch.ops.prims.mul.default(primals_2, 1.0); primals_2 = None
         broadcast_in_dim: "f32[4, 2]" = torch.ops.prims.broadcast_in_dim.default(mul_1, [4, 2], [1]); mul_1 = None
         add: "f32[4, 2]" = torch.ops.prims.add.default(mul, broadcast_in_dim); mul = broadcast_in_dim = None
-        mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, primals_4); add = None
-        mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, primals_4); tangents_1 = primals_4 = None
+        mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, 2.0); add = None
+        mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, 2.0); tangents_1 = None
         transpose_1: "f32[2, 4]" = torch.ops.prims.transpose.default(mul_3, [1, 0])
         mm_1: "f32[2, 3]" = torch.ops.aten.mm.default(transpose_1, primals_3); transpose_1 = primals_3 = None
         transpose_2: "f32[3, 2]" = torch.ops.prims.transpose.default(mm_1, [1, 0]); mm_1 = None
@@ -376,11 +365,12 @@ def forward(
         transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None
         return pytree.tree_unflatten([
             mul_2, # PlainAOTOutput(idx=0)
-            transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear.weight'))
-            as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear.bias'))
+            transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight'))
+            as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias'))
             None, # None
             None, # None
-        ], self._out_spec)""",
+        ], self._out_spec)
+        """,
     )
 
     # Compile the result
```
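
For context, the restored test passes the `nn.Module` itself (rather than a Dynamo-captured graph) together with a plain-float kwarg. A minimal standalone sketch of that calling convention follows; the import path is an assumption based on the function's home under `torch._functorch`, and `decompositions` is omitted on the assumption that it is optional (the test passes a decomposition table):

```python
# Minimal sketch of the calling convention the revert restores.
# Assumption: aot_export_joint_with_descriptors is importable from
# torch._functorch.aot_autograd (the diff does not show the test's imports).
from contextlib import ExitStack

import torch
import torch.nn as nn
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors


class ModuleWithKwargs(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale


model = ModuleWithKwargs()
inputs = (torch.randn(4, 3),)
kwargs = {"scale": 2.0}  # plain float, not a tensor

with ExitStack() as stack:
    # The module is passed directly, not a pre-captured Dynamo graph.
    joint = aot_export_joint_with_descriptors(stack, model, inputs, kwargs)
    print(joint.graph_module.print_readable(print_output=False))
```

Note how the expected graph above contains the literal `2.0` (in `mul_2` and `mul_3`) instead of a `primals_4` runtime input: a float kwarg is baked into the joint graph as a constant, which is exactly what the restored expected-output string asserts.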

torch/_functorch/_aot_autograd/graph_capture_wrappers.py

3 additions, 10 deletions

```diff
@@ -1342,15 +1342,6 @@ def functional_call(*args, **kwargs):
         maybe_disable_thunkify(),
     ):
         if isinstance(mod, torch.fx.GraphModule):
-            if kwargs:
-                # Handle **kwargs. FX only natively supports positional
-                # arguments (through placeholders).
-                arg_list = list(args[params_len:])
-                arg_list.extend(list(kwargs.values()))
-                args = tuple(arg_list)
-            else:
-                args = args[params_len:]
-
             with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
                 warnings.filterwarnings(
                     "ignore", "Anomaly Detection has been enabled."
@@ -1359,7 +1350,9 @@ def functional_call(*args, **kwargs):
                 fake_mode = detect_fake_mode()
                 assert fake_mode is not None
                 fake_mode.epoch += 1
-                out = PropagateUnbackedSymInts(mod).run(*args)
+                out = PropagateUnbackedSymInts(mod).run(
+                    *args[params_len:], **kwargs
+                )
         else:
             out = mod(*args[params_len:], **kwargs)
```
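
The deleted block existed because FX graphs expose inputs only as positional placeholders, so keyword arguments were flattened into the positional argument list before interpretation; the revert goes back to forwarding `kwargs` straight to `run()`. A standalone sketch of the flattening pattern, not PyTorch's implementation (`flatten_fx_call` and `scaled` are names invented here for illustration):

```python
# Sketch of the kwargs-flattening pattern from the deleted block above.
# FX only natively supports positional inputs (placeholders), so kwarg
# values are appended after the positional args before Interpreter.run().
# Assumption: the kwargs dict iterates in the same order as the trailing
# placeholders, as the reverted code also assumed.
import torch
import torch.fx


def scaled(x, scale):
    return x * scale


gm = torch.fx.symbolic_trace(scaled)  # placeholders: x, scale


def flatten_fx_call(gm: torch.fx.GraphModule, args, kwargs):
    flat_args = tuple(args) + tuple(kwargs.values())
    return torch.fx.Interpreter(gm).run(*flat_args)


print(flatten_fx_call(gm, (torch.ones(2),), {"scale": 2.0}))  # tensor([2., 2.])
```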
