Skip to content

Commit b3e435e

Browse files
mattteochenpre-commit-ci[bot]kiya00
authored
Supported kwargs in ThunderCompiler (#2759)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yan Wang <[email protected]>
1 parent 817c218 commit b3e435e

File tree

4 files changed

+65
-4
lines changed

4 files changed

+65
-4
lines changed

thunder/dynamo/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def __init__(self, **thunder_options):
129129
)
130130
self.thunder_options = thunder_options
131131

132-
def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
132+
def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor], **compile_options):
133133
from thunder import jit
134134

135135
remove_empty_autocast(gm)
@@ -148,6 +148,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
148148
gm,
149149
partial(jit, **thunder_options),
150150
thunder_options,
151+
**compile_options,
151152
)
152153
self.subgraph_infos.append(subgraph_info)
153154
return split_module

thunder/dynamo/splitter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def _splitter(
3434
gm: torch.fx.GraphModule,
3535
thunder_jit: Callable,
3636
thunder_options: dict[str, Any] | None = None,
37+
**compile_options,
3738
) -> tuple[torch.fx.GraphModule, SubgraphInfo]:
3839
"""
3940
This method will split graph into multiple graph modules based on thunder supported operations.
@@ -225,7 +226,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool:
225226
fake_mode = torch._guards.detect_fake_mode()
226227
# Delay Inductor compilation until invocation with real tensors,
227228
# because we do not know the strides of tensors that Thunder-compiled submodules return.
228-
jit_fn = LazyInductorModule(graph_module, fake_mode)
229+
jit_fn = LazyInductorModule(graph_module, fake_mode, **compile_options)
229230

230231
# Update the node name from "submod_*" to "inductor_*" for more user-friendly names
231232
update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn)

thunder/dynamo/utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from torch._guards import tracing, TracingContext
2121
from torch._subclasses.fake_tensor import DynamicOutputShapeException
2222

23+
from torch._inductor import list_mode_options
24+
2325
if torch.distributed.is_available():
2426
from torch.distributed.tensor import DTensor
2527
else:
@@ -160,11 +162,12 @@ def is_thunder_supported_partition(self, node: torch.fx.Node) -> bool:
160162

161163

162164
class LazyInductorModule(torch.nn.Module):
163-
def __init__(self, graph_module, fake_mode):
165+
def __init__(self, graph_module, fake_mode, **compile_options):
164166
super().__init__()
165167
self.graph_module = graph_module
166168
self.compiled_fn = None
167169
self.fake_mode = fake_mode
170+
self.compile_options = compile_options
168171

169172
# For ease of debugging, we add graph attribute so GraphModule.print_readable will print it
170173
self.graph = graph_module.graph
@@ -200,7 +203,14 @@ def forward(self, *args):
200203
with tracing(TracingContext(fake_mode=self.fake_mode)):
201204
try:
202205
original_graph = copy.deepcopy(self.graph_module.graph)
203-
self.compiled_fn = torch._inductor.compile(self.graph_module, args)
206+
# Extract and merge options from compile_options
207+
options = self.compile_options.get("options", {}).copy()
208+
mode = self.compile_options.get("mode")
209+
if mode:
210+
mode_options = list_mode_options().get(mode, {})
211+
options.update(mode_options)
212+
213+
self.compiled_fn = torch._inductor.compile(self.graph_module, args, options=options)
204214
except DynamicOutputShapeException as e:
205215
# This exception is meant to be handled by Dynamo, which is responsible for graph break
206216
# TODO: Use torch.compile for fallback. Ensure its correctness.

thunder/tests/test_dynamo.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,55 @@ def func(x):
180180
assert mock_inductor.call_count == 2
181181

182182

183+
@instantiate(
184+
dtypes=NOTHING,
185+
executors=[DynamoThunderExecutor],
186+
decorators=(
187+
pytest.mark.parametrize("mode", ("default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs")),
188+
pytest.mark.skipif(
189+
condition=IS_WINDOWS,
190+
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
191+
),
192+
),
193+
)
194+
def test_kwargs_forwarding_to_inductor(executor, device, dtype, mode):
195+
"""Test that torch.compile kwargs (like mode) are forwarded to inductor for unsupported regions."""
196+
from torch._inductor import list_mode_options
197+
198+
x = torch.randn(2, 2, device=device, dtype=dtype)
199+
200+
def func(x):
201+
# torch.sinc has automatic fallback registered,
202+
# so that operation will be given to inductor.
203+
return x.sinc()
204+
205+
cfunc = thunderfx(func, mode=mode)
206+
207+
# Mock torch._inductor.compile to verify arguments are passed correctly
208+
original_compile = torch._inductor.compile
209+
210+
with patch("torch._inductor.compile", side_effect=original_compile) as mock_compile:
211+
cfunc(x)
212+
213+
# Verify the mock was called (inductor fallback occurred)
214+
assert mock_compile.called
215+
216+
# Get the kwargs from the call
217+
_, compile_options = mock_compile.call_args
218+
219+
# Check if mode was expanded into options
220+
expected_options = list_mode_options().get(mode, {})
221+
222+
# At least empty dict options should be passed
223+
assert "options" in compile_options
224+
225+
if expected_options:
226+
# If the mode has options, verify they were passed
227+
options = compile_options["options"]
228+
for k, v in expected_options.items():
229+
assert options.get(k) == v, f"Expected option {k}={v} for mode {mode}, but got {options.get(k)}"
230+
231+
183232
@instantiate(
184233
dtypes=NOTHING,
185234
executors=[DynamoThunderExecutor],

0 commit comments

Comments
 (0)