You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Export should use aot_export_joint_with_descriptors (pytorch#165931)
This diff moves export run_decompositions to use aot_export_joint_with_descriptors instead of aot_export_module. Doing so, i ran into 2 main bugs:
1) aot_export_joint_with_descriptors don't correctly pass in record_nn_module_stack flag that is needed to populate nn_module_stack by switching the internal tracer.
2) When creating symint with negative inputs, we need to pass in positive=False. This didn't matter before because aot_autograd directly returns integer inputs instead of creating symint.
Pull Request resolved: pytorch#165931
Approved by: https://github.com/zhxchen17
ep = torch.export._trace._export(M(), inps, pre_dispatch=True)
13948
13973
self.assertExpectedInline(
@@ -15338,17 +15363,30 @@ def forward(self, x):
15338
15363
decomp_table,
15339
15364
)
15340
15365
15341
-
self.assertExpectedInline(
15342
-
str(ep.graph_module.code).strip(),
15343
-
"""\
15366
+
if IS_FBCODE:
15367
+
self.assertExpectedInline(
15368
+
str(ep.graph_module.code).strip(),
15369
+
"""\
15344
15370
def forward(self, x):
15345
15371
foo_functional = torch.ops.testlib.foo_functional.default(x); x = None
15346
15372
cos = torch.ops.aten.cos.default(foo_functional)
15347
15373
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = foo_functional, z = cos); foo_functional = cos = None
0 commit comments