@@ -1,5 +1,6 @@
 from __future__ import annotations
 from collections.abc import Callable, Sequence
+from contextlib import contextmanager
 from enum import Enum, auto
 from typing import TYPE_CHECKING
 import dataclasses
@@ -8,12 +9,15 @@
 from types import NoneType
 from collections import defaultdict
 from collections import namedtuple
+import warnings
+
+from looseversion import LooseVersion
 
 import torch
 from torch.nn.modules.module import _addindent
 from torch.utils.weak import TensorWeakRef
-from torch._guards import detect_fake_mode
-from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
+from torch._guards import tracing, TracingContext
+from torch._subclasses.fake_tensor import DynamicOutputShapeException
 
 if torch.distributed.is_available():
     from torch.distributed.tensor import DTensor
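
Aside (not part of this change): the new `contextmanager` and `LooseVersion` imports support a version-gated monkey patch used by `LazyInductorModule` below. A minimal generic sketch of that pattern, assuming only that `torch.__version__` is a plain version string; the helper name `maybe_patch` and its arguments are illustrative, not taken from this diff:

    from contextlib import contextmanager

    import torch
    from looseversion import LooseVersion


    @contextmanager
    def maybe_patch(target_obj, attr_name, replacement):
        # Only patch on PyTorch releases older than 2.8.0; otherwise do nothing.
        if LooseVersion(torch.__version__) >= LooseVersion("2.8.0"):
            yield
            return
        original = getattr(target_obj, attr_name)
        setattr(target_obj, attr_name, replacement)
        try:
            yield
        finally:
            setattr(target_obj, attr_name, original)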
@@ -154,6 +158,56 @@ def is_thunder_supported_partition(self, node: torch.fx.Node) -> bool: |
         return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in self.supported_indexes
 
 
+class LazyInductorModule(torch.nn.Module):
+    def __init__(self, graph_module, fake_mode):
+        super().__init__()
+        self.graph_module = graph_module
+        self.compiled_fn = None
+        self.fake_mode = fake_mode
+
+        # For ease of debugging, expose a `graph` attribute so that GraphModule.print_readable prints this module's graph.
+        self.graph = graph_module.graph
+
+    # TODO: Remove this once we drop support for PyTorch 2.7.x
+    @contextmanager
+    def _maybe_patch_increment_toplevel(self):
+        # In PyTorch before 2.8.0, FXGraphCache assumes that it runs behind Dynamo
+        # and tries to update the metrics context.
+        # See https://github.com/pytorch/pytorch/pull/150423
+        if LooseVersion(torch.__version__) >= LooseVersion("2.8.0"):
+            yield
+            return
+
+        from torch._dynamo.utils import CompileEventLogger
+
+        def fake_increment_toplevel(*args, **kwargs):
+            metrics_context = torch._dynamo.utils.get_metrics_context()
+            assert not metrics_context.in_progress()
+            return
+
+        original = CompileEventLogger.increment_toplevel
+        CompileEventLogger.increment_toplevel = fake_increment_toplevel
+        try:
+            yield
+        finally:
+            CompileEventLogger.increment_toplevel = original
+
+    def forward(self, *args):
+        if self.compiled_fn is None:
+            with self._maybe_patch_increment_toplevel():
+                # Inductor needs fake_mode, in particular its shape_env, to handle SymInts.
+                with tracing(TracingContext(fake_mode=self.fake_mode)):
+                    try:
+                        self.compiled_fn = torch._inductor.compile(self.graph_module, args)
+                    except DynamicOutputShapeException as e:
+                        # This exception is meant to be handled by Dynamo, which is responsible for the graph break.
+                        # TODO: Use torch.compile for the fallback and ensure its correctness.
+                        warnings.warn(f"Dynamic output shape operator encountered: {e}. Falling back to eager.")
+                        self.compiled_fn = self.graph_module
+
+        return self.compiled_fn(*args)
+
+
 @dataclasses.dataclass()
 class ProfileStats:
     """
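
Aside (not part of this change): a hedged sketch of how `LazyInductorModule` could be exercised on its own. The traced function `f`, the use of `torch.fx.symbolic_trace`, and the fresh `FakeTensorMode` are illustrative assumptions, not taken from this diff:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode


    def f(x):
        return torch.relu(x) + 1


    gm = torch.fx.symbolic_trace(f)

    # The first call compiles gm with Inductor (or falls back to eager on
    # DynamicOutputShapeException); later calls reuse the cached compiled_fn.
    lazy = LazyInductorModule(gm, fake_mode=FakeTensorMode())
    out = lazy(torch.randn(4, 4))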
@@ -1064,25 +1118,6 @@ def get_compiled_fn_and_timing(report, compile_fn, timer_fn): |
     return sorted_compiled_gm_to_measurement[0].compiled_fn
 
 
-def make_fake_arguments(gm: torch.fx.GraphModule) -> list[FakeTensor] | None:
-    fake_mode = detect_fake_mode()
-    if fake_mode is None:
-        fake_mode = FakeTensorMode()
-    args = []
-    for node in gm.graph.nodes:
-        if node.op == "placeholder":
-            meta_val = node.meta.get("example_value")
-            if meta_val is None:
-                # We observed Dynamo creating nodes without `example_value` on Tensor.tolist().
-                # This no longer happens in PyTorch 2.10 (see https://github.com/pytorch/pytorch/pull/163807).
-                return None
-            if isinstance(meta_val, torch.Tensor):
-                # Tie to the currently enabled fake mode
-                meta_val = fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, meta_val)
-            args.append(meta_val)
-    return args
-
-
 def translate_dtensor_ops(gm: torch.fx.GraphModule) -> None:
     # We need this function because:
     #