Commit bbb902c

angelayi authored and pytorchmergebot committed
[export] Handle kwargs better in aot_export_joint_with_descriptors (pytorch#165334)
fx.Interpreter doesn't handle kwargs... not sure how this code worked previously.

Pull Request resolved: pytorch#165334
Approved by: https://github.com/tugsbayasgalan, https://github.com/ezyang
1 parent e6f766c · commit bbb902c
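
For context on the one-line summary above: torch.fx.Interpreter.run accepts graph inputs positionally only (one value per placeholder), so forwarding user keyword arguments to it raises a TypeError. A minimal sketch of that limitation, illustrating the failure mode rather than reproducing the exact old call site:

import torch
import torch.fx

def f(x, scale):
    return x * scale

gm = torch.fx.symbolic_trace(f)
interp = torch.fx.Interpreter(gm)

interp.run(torch.randn(4), torch.tensor(2.0))  # OK: one positional arg per placeholder
try:
    interp.run(torch.randn(4), scale=torch.tensor(2.0))
except TypeError as e:
    print(e)  # Interpreter.run has no **kwargs parameter, so user keywords are rejected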

File tree

2 files changed: +33 -16 lines changed


test/functorch/test_aot_joint_with_descriptors.py

Lines changed: 23 additions & 13 deletions
@@ -315,47 +315,58 @@ def __init__(self):
                 super().__init__()
                 self.linear = nn.Linear(3, 2)
 
-            def forward(self, x, scale=1.0):
+            def forward(self, x, *, scale):
                 return self.linear(x) * scale
 
         model = ModuleWithKwargs()
         inputs = (torch.randn(4, 3),)
-        kwargs = {"scale": 2.0}
+        kwargs = {"scale": torch.tensor(2.0)}
+
+        gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)
 
         with ExitStack() as stack:
             # Export joint with descriptors
             joint_with_descriptors = aot_export_joint_with_descriptors(
-                stack, model, inputs, kwargs, decompositions=decomposition_table
+                stack, gm, inputs, kwargs, decompositions=decomposition_table
             )
 
             # Test the exported graph structure
             graph_code = joint_with_descriptors.graph_module.print_readable(
                 print_output=False, expanded_def=True
             )
 
+            # For some reason PYTORCH_TEST_WITH_CROSSREF will add extra spaces.
+            # I tried to fix this in normalize_gm but there are too many files
+            # depending on that behavior..
+            graph_code_str = normalize_gm(graph_code)
+            graph_code_str = "\n".join(
+                [line for line in graph_code_str.split("\n") if len(line.rstrip()) > 0]
+            )
+
             # Expect test on the printed graph
             self.assertExpectedInline(
-                normalize_gm(graph_code),
+                graph_code_str,
                 """\
 class inner_f(torch.nn.Module):
     def forward(
         self,
         primals,
         tangents,
     ):
-        primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight')
-        primals_2: "f32[2]" # ParamAOTInput(target='linear.bias')
+        primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear.weight')
+        primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear.bias')
         primals_3: "f32[4, 3]" # PlainAOTInput(idx=0)
+        primals_4: "f32[]" # PlainAOTInput(idx=1)
         tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0))
-        primals_1, primals_2, primals_3, primals_4 , tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
+        primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
         transpose: "f32[3, 2]" = torch.ops.prims.transpose.default(primals_1, [1, 0]); primals_1 = None
         mm: "f32[4, 2]" = torch.ops.aten.mm.default(primals_3, transpose); transpose = None
         mul: "f32[4, 2]" = torch.ops.prims.mul.default(mm, 1.0); mm = None
         mul_1: "f32[2]" = torch.ops.prims.mul.default(primals_2, 1.0); primals_2 = None
         broadcast_in_dim: "f32[4, 2]" = torch.ops.prims.broadcast_in_dim.default(mul_1, [4, 2], [1]); mul_1 = None
         add: "f32[4, 2]" = torch.ops.prims.add.default(mul, broadcast_in_dim); mul = broadcast_in_dim = None
-        mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, 2.0); add = None
-        mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, 2.0); tangents_1 = None
+        mul_2: "f32[4, 2]" = torch.ops.prims.mul.default(add, primals_4); add = None
+        mul_3: "f32[4, 2]" = torch.ops.prims.mul.default(tangents_1, primals_4); tangents_1 = primals_4 = None
         transpose_1: "f32[2, 4]" = torch.ops.prims.transpose.default(mul_3, [1, 0])
         mm_1: "f32[2, 3]" = torch.ops.aten.mm.default(transpose_1, primals_3); transpose_1 = primals_3 = None
         transpose_2: "f32[3, 2]" = torch.ops.prims.transpose.default(mm_1, [1, 0]); mm_1 = None
@@ -365,12 +376,11 @@ def forward(
         transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None
         return pytree.tree_unflatten([
             mul_2, # PlainAOTOutput(idx=0)
-            transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight'))
-            as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias'))
+            transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear.weight'))
+            as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear.bias'))
             None, # None
             None, # None
-        ], self._out_spec)
-""",
+        ], self._out_spec)""",
             )
 
         # Compile the result
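
The updated test captures with dynamo first and hands the resulting GraphModule to the joint export. A condensed sketch of that flow, with import paths assumed since the test's import block is not part of this diff:

import torch
import torch.nn as nn
from contextlib import ExitStack

# Assumed import locations -- not shown in this diff.
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors

class ModuleWithKwargs(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, x, *, scale):
        return self.linear(x) * scale

model = ModuleWithKwargs()
inputs = (torch.randn(4, 3),)
# A tensor kwarg becomes a real graph input (primals_4 in the expected
# graph above) instead of a constant baked in like the old float 2.0.
kwargs = {"scale": torch.tensor(2.0)}

# Capture first, then export the joint graph from the GraphModule.
gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs)
with ExitStack() as stack:
    joint = aot_export_joint_with_descriptors(stack, gm, inputs, kwargs)
    print(joint.graph_module.print_readable(print_output=False))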

torch/_functorch/_aot_autograd/graph_capture_wrappers.py

Lines changed: 10 additions & 3 deletions
@@ -1342,6 +1342,15 @@ def functional_call(*args, **kwargs):
             maybe_disable_thunkify(),
         ):
             if isinstance(mod, torch.fx.GraphModule):
+                if kwargs:
+                    # Handle **kwargs. FX only natively supports positional
+                    # arguments (through placeholders).
+                    arg_list = list(args[params_len:])
+                    arg_list.extend(list(kwargs.values()))
+                    args = tuple(arg_list)
+                else:
+                    args = args[params_len:]
+
                 with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
                     warnings.filterwarnings(
                         "ignore", "Anomaly Detection has been enabled."
@@ -1350,9 +1359,7 @@ def functional_call(*args, **kwargs):
                     fake_mode = detect_fake_mode()
                     assert fake_mode is not None
                     fake_mode.epoch += 1
-                    out = PropagateUnbackedSymInts(mod).run(
-                        *args[params_len:], **kwargs
-                    )
+                    out = PropagateUnbackedSymInts(mod).run(*args)
             else:
                 out = mod(*args[params_len:], **kwargs)
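
The fix works because every input to an fx.GraphModule is a positional placeholder: a captured kwarg is just one more trailing placeholder, so appending the keyword values in insertion order lines them up with the placeholders recorded at capture time. A standalone sketch of the same flattening idea (run_flattened is an illustrative name, not from this patch):

import torch
import torch.fx

def run_flattened(gm: torch.fx.GraphModule, args, kwargs):
    # Interpreter.run takes one positional value per placeholder, so the
    # keyword values are appended in dict insertion order; this must match
    # the placeholder order established when the graph was captured.
    flat = tuple(args) + tuple(kwargs.values())
    return torch.fx.Interpreter(gm).run(*flat)

def f(x, scale):
    return x * scale

gm = torch.fx.symbolic_trace(f)
out = run_flattened(gm, (torch.randn(4, 2),), {"scale": torch.tensor(2.0)})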
