
Commit 6096c0f

tugsbayasgalan authored and pytorchmergebot committed
Export should use aot_export_joint_with_descriptors (pytorch#165931)
This diff moves export's run_decompositions to use aot_export_joint_with_descriptors instead of aot_export_module. In doing so, I ran into two main bugs:

1) aot_export_joint_with_descriptors doesn't correctly pass in the record_nn_module_stack flag, which is needed to populate nn_module_stack by switching the internal tracer.
2) When creating a symint from a negative input, we need to pass positive=False. This didn't matter before because aot_autograd returned integer inputs directly instead of creating symints.

Pull Request resolved: pytorch#165931
Approved by: https://github.com/zhxchen17
1 parent f6951cb commit 6096c0f
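
For orientation, here is a minimal sketch of the API this diff routes export through. The module, inputs, and the ExitStack handling are illustrative assumptions; only the _record_nn_module_stack flag comes from this PR:

import contextlib

import torch
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors


class M(torch.nn.Module):
    def forward(self, x):
        return x.cos()


# The ExitStack keeps the tracing contexts alive while the caller works
# with the captured joint graph. _record_nn_module_stack makes the
# internal tracer a ModuleStackTracer so nn_module_stack is populated.
with contextlib.ExitStack() as stack:
    joint = aot_export_joint_with_descriptors(
        stack,
        M(),
        (torch.ones(5),),
        _record_nn_module_stack=True,
    )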

7 files changed (+167 −56 lines)


test/export/test_export.py

Lines changed: 50 additions & 12 deletions
@@ -13910,16 +13910,28 @@ def forward(self, x):
         inps = (torch.ones(5),)

         ep = torch.export.export(M(), inps).run_decompositions({})
-        self.assertExpectedInline(
-            str(ep.graph_module.code.strip()),
-            """\
+        if IS_FBCODE:
+            self.assertExpectedInline(
+                str(ep.graph_module.code.strip()),
+                """\
 def forward(self, x):
     cos = torch.ops.aten.cos.default(x)
     auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos); x = cos = None
     getitem_3 = auto_functionalized[3]; auto_functionalized = None
     cos_1 = torch.ops.aten.cos.default(getitem_3)
     return (getitem_3, getitem_3, cos_1)""",
-        )
+            )
+        else:
+            self.assertExpectedInline(
+                str(ep.graph_module.code.strip()),
+                """\
+def forward(self, x):
+    cos = torch.ops.aten.cos.default(x)
+    auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.testlib.foo.default, _x_base_index = 0, _z_base_index = 1, _all_bases = [x, cos]); x = cos = None
+    getitem_3 = auto_functionalized_v2[3]; auto_functionalized_v2 = None
+    cos_1 = torch.ops.aten.cos.default(getitem_3)
+    return (getitem_3, getitem_3, cos_1)""",
+            )

     def test_custom_op_auto_warn_pre_dispatch(self):
         class M(torch.nn.Module):
@@ -13932,17 +13944,30 @@ def forward(self, x):
         inps = (torch.ones(5),)

         ep = torch.export.export(M(), inps).run_decompositions()
-        self.assertExpectedInline(
-            str(ep.graph_module.code.strip()),
-            """\
+        if IS_FBCODE:
+            self.assertExpectedInline(
+                str(ep.graph_module.code.strip()),
+                """\
 def forward(self, x):
     cos = torch.ops.aten.cos.default(x)
     cos_1 = torch.ops.aten.cos.default(x); x = None
     auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = cos, z = cos_1); cos = cos_1 = None
     getitem_3 = auto_functionalized[3]; auto_functionalized = None
     cos_2 = torch.ops.aten.cos.default(getitem_3); getitem_3 = None
     return (cos_2,)""",
-        )
+            )
+        else:
+            self.assertExpectedInline(
+                str(ep.graph_module.code.strip()),
+                """\
+def forward(self, x):
+    cos = torch.ops.aten.cos.default(x)
+    cos_1 = torch.ops.aten.cos.default(x); x = None
+    auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.testlib.foo.default, _x_base_index = 0, _z_base_index = 1, _all_bases = [cos, cos_1]); cos = cos_1 = None
+    getitem_3 = auto_functionalized_v2[3]; auto_functionalized_v2 = None
+    cos_2 = torch.ops.aten.cos.default(getitem_3); getitem_3 = None
+    return (cos_2,)""",
+            )

         ep = torch.export._trace._export(M(), inps, pre_dispatch=True)
         self.assertExpectedInline(
@@ -15338,17 +15363,30 @@ def forward(self, x):
             decomp_table,
         )

-        self.assertExpectedInline(
-            str(ep.graph_module.code).strip(),
-            """\
+        if IS_FBCODE:
+            self.assertExpectedInline(
+                str(ep.graph_module.code).strip(),
+                """\
 def forward(self, x):
     foo_functional = torch.ops.testlib.foo_functional.default(x); x = None
     cos = torch.ops.aten.cos.default(foo_functional)
     auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = foo_functional, z = cos); foo_functional = cos = None
     getitem_3 = auto_functionalized[3]; auto_functionalized = None
     cos_1 = torch.ops.aten.cos.default(getitem_3)
     return (getitem_3, cos_1)""",
-        )
+            )
+        else:
+            self.assertExpectedInline(
+                str(ep.graph_module.code).strip(),
+                """\
+def forward(self, x):
+    foo_functional = torch.ops.testlib.foo_functional.default(x); x = None
+    cos = torch.ops.aten.cos.default(foo_functional)
+    auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.testlib.foo.default, _x_base_index = 0, _z_base_index = 1, _all_bases = [foo_functional, cos]); foo_functional = cos = None
+    getitem_3 = auto_functionalized_v2[3]; auto_functionalized_v2 = None
+    cos_1 = torch.ops.aten.cos.default(getitem_3)
+    return (getitem_3, cos_1)""",
+            )

     def test_run_decompositions_keep_metadata(self):
         """Make sure the metadata is kept after exported program run_decompositions."""

torch/_export/config.py

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,7 @@
 import sys
 from typing import Any, TYPE_CHECKING

+from torch._environment import is_fbcode
 from torch.utils._config_module import install_config_module


@@ -27,6 +28,11 @@
 # that we don't know how to proxy, resulting in untracked fake tensors
 error_on_lifted_constant_tensors = True

+# enable auto_functionalized_v2 in export
+# We turn this off in fbcode due to downstream users not
+# being ready to handle auto_functionalized_v2.
+enable_auto_functionalized_v2_for_export = not is_fbcode()
+
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
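
Because this module is installed via install_config_module, the new flag can be inspected and patched like other torch config modules. A minimal sketch (the patch usage is an assumption based on the standard config-module interface; only the flag name comes from this diff):

import torch._export.config as export_config

# OSS default is True; fbcode default is False (see is_fbcode() above).
print(export_config.enable_auto_functionalized_v2_for_export)

# Temporarily force the v1 auto_functionalized HOP during export,
# e.g. to reproduce the IS_FBCODE expected graphs from the tests above.
with export_config.patch(enable_auto_functionalized_v2_for_export=False):
    ep = ...  # torch.export.export(...).run_decompositions({})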

torch/_functorch/_aot_autograd/collect_metadata_analysis.py

Lines changed: 1 addition & 4 deletions
@@ -166,9 +166,6 @@ def run_functionalized_fw_and_collect_metadata(
     # Note: this is guaranteed to be set when running under dynamo
     static_input_indices: Optional[list[int]] = None,
     pre_dispatch: bool = False,
-    # is_export is technically only needed to avoid using functionalization V2
-    # during analysis
-    is_export: bool = False,
 ) -> Callable[..., ViewAndMutationMeta]:
     memo: dict[Tensor, Tensor] = {}

@@ -200,7 +197,7 @@ def inner(*flat_args):

         # It doesn't matter if we run this under predispatch or not because it is
         # only for figuring out metadata
-        mode = FunctionalTensorMode(_allow_token_discovery=True, export=is_export)
+        mode = FunctionalTensorMode(_allow_token_discovery=True)
         suppress_pending = contextlib.nullcontext()
         fake_mode = detect_fake_mode()
         if fake_mode and (shape_env := fake_mode.shape_env):

torch/_functorch/_aot_autograd/frontend_utils.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def convert(idx, x):
             return x
         source = ConstantSource(f"sym_{idx}")
         return shape_env.create_symintnode(
-            shape_env.create_symbol(x, source),
+            shape_env.create_symbol(x, source, positive=x >= 0),
             hint=x,
             source=source,
         )
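
A minimal sketch of the bug this fixes (the ShapeEnv/ConstantSource usage mirrors the code above; the -3 input is an illustrative assumption). Without positive=False, create_symbol records the new symbol as positive, which contradicts a negative integer input once export actually creates a symint for it:

from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
source = ConstantSource("sym_0")
x = -3  # a negative int input; aot_autograd used to return these directly

# positive=x >= 0 tells the ShapeEnv this symbol may be negative.
sym = shape_env.create_symintnode(
    shape_env.create_symbol(x, source, positive=x >= 0),
    hint=x,
    source=source,
)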

torch/_functorch/aot_autograd.py

Lines changed: 10 additions & 2 deletions
@@ -573,7 +573,6 @@ def _dup_fake_script_obj(fake_flat_args):
         keep_input_mutations=aot_config.keep_inference_input_mutations,
         is_train=needs_autograd,
         pre_dispatch=aot_config.pre_dispatch,
-        is_export=aot_config.is_export,
     )(*_dup_fake_script_obj(fake_flat_args))

     req_subclass_dispatch = requires_subclass_dispatch(
@@ -905,6 +904,7 @@ def prepare_aot_module_simplified(
     *,
     force_non_lazy_backward_lowering: bool = False,
     disable_functionalization: bool = False,
+    _record_nn_module_stack: bool = False,
 ):
     if not flatten:
         assert kwargs is None
@@ -931,7 +931,13 @@
     # NB: This doesn't change the in/out convention, except adding the
     # parameters as explicit arguments
     functional_call = create_functional_call(
-        mod, params_buffers_spec, params_len + buffers_len, strict_out_tuple=not flatten
+        mod,
+        params_buffers_spec,
+        params_len + buffers_len,
+        strict_out_tuple=not flatten,
+        # We need this for export to run ModuleStackTracer
+        # instead of PythonKeyTracer
+        store_orig_mod=_record_nn_module_stack,
     )

     full_args = [*params_flat, *buffers_flat, *args]
@@ -1175,6 +1181,7 @@ def aot_export_joint_with_descriptors(
     keep_inference_input_mutations=False,
     ignore_shape_env=False,
     disable_functionalization=False,
+    _record_nn_module_stack=False,
 ) -> JointWithDescriptors:
     """
     This API captures the joint graph for an nn.Module. However, unlike
@@ -1265,6 +1272,7 @@ def aot_export_joint_with_descriptors(
         # context.
         force_non_lazy_backward_lowering=True,
         disable_functionalization=disable_functionalization,
+        _record_nn_module_stack=_record_nn_module_stack,
     )

     # TODO: Maybe this should be in create_aot_state? Not sure, that would
torch/_subclasses/functional_tensor.py

Lines changed: 9 additions & 3 deletions
@@ -145,7 +145,7 @@ def __new__(cls, elem, mode):
         out.elem = elem

         if (
-            not mode.export
+            torch._export.config.enable_auto_functionalized_v2_for_export
             and torch.is_inference_mode_enabled()
             and torch._inductor.config.enable_auto_functionalized_v2
         ):
@@ -449,12 +449,18 @@ def unwrap(x):
         ) and not torch._C._dispatch_has_kernel_for_dispatch_key(
             func.name(), torch._C.DispatchKey.Functionalize
         ):
+            import torch._export.config as export_config
             import torch._inductor.config as inductor_config

-            if self.export or not inductor_config.enable_auto_functionalized_v2:
+            if torch.compiler.is_exporting():
+                if export_config.enable_auto_functionalized_v2_for_export:
+                    return do_auto_functionalize_v2(self, func, args, kwargs)
+
                 return do_auto_functionalize(self, func, args, kwargs)
-            else:
+
+            if inductor_config.enable_auto_functionalized_v2:
                 return do_auto_functionalize_v2(self, func, args, kwargs)
+            return do_auto_functionalize(self, func, args, kwargs)

         from torch._higher_order_ops.effects import handle_effects, has_effects