|
45 | 45 | "dims", |
46 | 46 | "export", |
47 | 47 | "export_for_training", |
48 | | - "export_for_inference", |
49 | 48 | "load", |
50 | 49 | "register_dataclass", |
51 | 50 | "save", |
@@ -167,91 +166,6 @@ def export_for_training( |
167 | 166 | ) |
168 | 167 |
|
169 | 168 |
|
170 | | -def export_for_inference( |
171 | | - mod: torch.nn.Module, |
172 | | - args: tuple[Any, ...], |
173 | | - kwargs: Optional[dict[str, Any]] = None, |
174 | | - *, |
175 | | - dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, |
176 | | - strict: bool = True, |
177 | | - preserve_module_call_signature: tuple[str, ...] = (), |
178 | | - decomp_table: Optional[dict["OpOverload", Optional[Callable]]] = None, |
179 | | -) -> ExportedProgram: |
180 | | - """ |
181 | | - :func:`export_for_inference` takes any nn.Module along with example inputs, and produces a traced graph representing |
182 | | - only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, |
183 | | - which can subsequently be executed with different inputs or serialized. The |
184 | | - traced graph (1) produces normalized operators in the ATen operator set |
185 | | - (as well as any user-specified custom operators) which is customizable via decomp_table, |
186 | | - (2) has eliminated all Python control flow and data structures (with certain exceptions), |
187 | | - and (3) records the set of shape constraints needed to show that this normalization and control-flow |
188 | | - elimination is sound for future inputs. This API is provided for convenience, as it combines :func:`export_for_training` and |
189 | | - :func:`run_decompositions`. |
190 | | -
|
191 | | - **Soundness Guarantee** |
192 | | -
|
193 | | - See :func:`export()` docstring for more details. |
194 | | -
|
195 | | - Args: |
196 | | - mod: We will trace the forward method of this module. |
197 | | -
|
198 | | - args: Example positional inputs. |
199 | | -
|
200 | | - kwargs: Optional example keyword inputs. |
201 | | -
|
202 | | - dynamic_shapes: |
203 | | - An optional argument where the type should either be: |
203 | | - 1) a dict from argument names of ``mod``'s forward method to their dynamic shape specifications, |
205 | | - 2) a tuple that specifies dynamic shape specifications for each input in original order. |
206 | | - If you are specifying dynamism on keyword args, you will need to pass them in the order that |
207 | | - is defined in the original function signature. |
208 | | -
|
209 | | - The dynamic shape of a tensor argument can be specified as either |
210 | | - (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is |
211 | | - not required to include static dimension indices in this dict, but when they are, |
212 | | - they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, |
213 | | - where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions |
214 | | - are denoted by None. Arguments that are dicts or tuples / lists of tensors are |
215 | | - recursively specified by using mappings or sequences of contained specifications. |
216 | | -
|
217 | | - strict: When enabled (default), the export function will trace the program through |
218 | | - TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the |
219 | | - exported program will not validate the implicit assumptions baked into the graph and |
220 | | - may cause behavior divergence between the original model and the exported one. This is |
221 | | - useful when users need to work around bugs in the tracer, or simply want to incrementally |
222 | | - enable safety in their models. Note that this does not change the resulting IR spec, |
223 | | - and the model will be serialized in the same way regardless of what value |
224 | | - is passed here. |
225 | | - WARNING: This option is experimental; use it at your own risk. |
226 | | -
|
227 | | - decomp_table: See :func:`run_decompositions` for more details. |
228 | | -
|
229 | | - Returns: |
230 | | - An :class:`ExportedProgram` containing the traced callable. |
231 | | -
|
232 | | - **Acceptable input/output types** |
233 | | -
|
234 | | - Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: |
235 | | -
|
236 | | - - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. |
237 | | - - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. |
238 | | - - (Nested) Data structures comprising ``dict``, ``list``, ``tuple``, ``namedtuple`` and |
239 | | - ``OrderedDict`` containing all above types. |
240 | | -
|
241 | | - """ |
242 | | - |
243 | | - ep_for_training = export_for_training( |
244 | | - mod, |
245 | | - args, |
246 | | - kwargs, |
247 | | - dynamic_shapes=dynamic_shapes, |
248 | | - strict=strict, |
249 | | - preserve_module_call_signature=preserve_module_call_signature, |
250 | | - ) |
251 | | - |
252 | | - return ep_for_training.run_decompositions(decomp_table=decomp_table) |
253 | | - |
254 | | - |
255 | 169 | def export( |
256 | 170 | mod: torch.nn.Module, |
257 | 171 | args: tuple[Any, ...], |
|
0 commit comments