Skip to content

Commit e9980dc

Browse files
gmagogsfmamathewc
authored and committed
Remove torch.export.export_for_inference (pytorch#149078)
Summary: Remove torch.export.export_for_inference; it is redundant and can always be replaced with torch.export.export_for_training() + run_decompositions(). Test Plan: unit tests. Differential Revision: D71069057. Pull Request resolved: pytorch#149078. Approved by: https://github.com/tugsbayasgalan
1 parent 41ade29 commit e9980dc

File tree

2 files changed

+0
-136
lines changed

2 files changed

+0
-136
lines changed

test/export/test_export.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7000,56 +7000,6 @@ def forward(self, x):
70007000
self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp)))
70017001
self.assertEqual(id(state_dict), id(ep.state_dict))
70027002

7003-
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
7004-
def test_export_for_inference_e2e(self):
7005-
class M(torch.nn.Module):
7006-
def __init__(self) -> None:
7007-
super().__init__()
7008-
self.lin = torch.nn.Linear(10, 1)
7009-
7010-
def forward(self, x):
7011-
return self.lin(x)
7012-
7013-
inp = (torch.randn(5, 10),)
7014-
m = M()
7015-
7016-
decomp_table = torch.export.default_decompositions()
7017-
7018-
def _custom_decomp_for_linear(x, weight, bias):
7019-
return x + bias.sum()
7020-
7021-
decomp_table[torch.ops.aten.linear.default] = _custom_decomp_for_linear
7022-
del decomp_table[torch.ops.aten.sum.default]
7023-
ep = torch.export.export_for_inference(
7024-
m, inp, decomp_table=decomp_table, dynamic_shapes={"x": {0: Dim("batch")}}
7025-
)
7026-
7027-
self.assertExpectedInline(
7028-
str(ep.graph_module.code).strip(),
7029-
"""\
7030-
def forward(self, p_lin_weight, p_lin_bias, x):
7031-
sum_1 = torch.ops.aten.sum.default(p_lin_bias); p_lin_bias = None
7032-
add = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None
7033-
return (add,)""",
7034-
)
7035-
7036-
ep_core = ep.run_decompositions()
7037-
7038-
self.assertExpectedInline(
7039-
str(ep_core.graph_module.code).strip(),
7040-
"""\
7041-
def forward(self, p_lin_weight, p_lin_bias, x):
7042-
sum_1 = torch.ops.aten.sum.dim_IntList(p_lin_bias, []); p_lin_bias = None
7043-
add = torch.ops.aten.add.Tensor(x, sum_1); x = sum_1 = None
7044-
return (add,)""",
7045-
)
7046-
7047-
with self.assertRaisesRegex(RuntimeError, "Expected input"):
7048-
ep.module()(torch.randn(4, 12))
7049-
7050-
with self.assertRaisesRegex(RuntimeError, "Expected input"):
7051-
ep_core.module()(torch.randn(4, 12))
7052-
70537003
@unittest.skipIf(IS_FBCODE, "We can't customize decomp in fbcode")
70547004
def test_export_decomp_torture_case_1(self):
70557005
class M(torch.nn.Module):

torch/export/__init__.py

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
"dims",
4646
"export",
4747
"export_for_training",
48-
"export_for_inference",
4948
"load",
5049
"register_dataclass",
5150
"save",
@@ -167,91 +166,6 @@ def export_for_training(
167166
)
168167

169168

170-
def export_for_inference(
171-
mod: torch.nn.Module,
172-
args: tuple[Any, ...],
173-
kwargs: Optional[dict[str, Any]] = None,
174-
*,
175-
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
176-
strict: bool = True,
177-
preserve_module_call_signature: tuple[str, ...] = (),
178-
decomp_table: Optional[dict["OpOverload", Optional[Callable]]] = None,
179-
) -> ExportedProgram:
180-
"""
181-
:func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing
182-
only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
183-
which can subsequently be executed with different inputs or serialized. The
184-
traced graph (1) produces normalized operators in the ATen operator set
185-
(as well as any user-specified custom operators) which is customizable via decomp_table,
186-
(2) has eliminated all Python control flow and data structures (with certain exceptions),
187-
and (3) records the set of shape constraints needed to show that this normalization and control-flow
188-
elimination is sound for future inputs. This API is for convenience use as it combines :func:`export_for_training` and
189-
:func:`run_decompositions`.
190-
191-
**Soundness Guarantee**
192-
193-
See :func:`export()` docstring for more details.
194-
195-
Args:
196-
mod: We will trace the forward method of this module.
197-
198-
args: Example positional inputs.
199-
200-
kwargs: Optional example keyword inputs.
201-
202-
dynamic_shapes:
203-
An optional argument where the type should either be:
204-
1) a dict from argument names of ``f`` to their dynamic shape specifications,
205-
2) a tuple that specifies dynamic shape specifications for each input in original order.
206-
If you are specifying dynamism on keyword args, you will need to pass them in the order that
207-
is defined in the original function signature.
208-
209-
The dynamic shape of a tensor argument can be specified as either
210-
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
211-
not required to include static dimension indices in this dict, but when they are,
212-
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
213-
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
214-
are denoted by None. Arguments that are dicts or tuples / lists of tensors are
215-
recursively specified by using mappings or sequences of contained specifications.
216-
217-
strict: When enabled (default), the export function will trace the program through
218-
TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the
219-
exported program will not validate the implicit assumptions baked into the graph and
220-
may cause behavior divergence between the original model and the exported one. This is
221-
useful when users need to workaround bugs in the tracer, or simply want incrementally
222-
enable safety in their models. Note that this does not affect the resulting IR spec
223-
to be different and the model will be serialized in the same way regardless of what value
224-
is passed here.
225-
WARNING: This option is experimental and use this at your own risk.
226-
227-
decomp_table: See :func:`run_decompositions` for more details.
228-
229-
Returns:
230-
An :class:`ExportedProgram` containing the traced callable.
231-
232-
**Acceptable input/output types**
233-
234-
Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:
235-
236-
- Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
237-
- Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
238-
- (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and
239-
``OrderedDict`` containing all above types.
240-
241-
"""
242-
243-
ep_for_training = export_for_training(
244-
mod,
245-
args,
246-
kwargs,
247-
dynamic_shapes=dynamic_shapes,
248-
strict=strict,
249-
preserve_module_call_signature=preserve_module_call_signature,
250-
)
251-
252-
return ep_for_training.run_decompositions(decomp_table=decomp_table)
253-
254-
255169
def export(
256170
mod: torch.nn.Module,
257171
args: tuple[Any, ...],

0 commit comments

Comments
 (0)