
Commit ec045fe

Propagate disable_torch_autograd to thunderfx's _splitter

For example, a `torch.library.custom_op` that is registered to Thunder without a backward definition can cause a graph split, because the splitter decides support based only on the input proxies' `requires_grad` -- https://github.com/Lightning-AI/lightning-thunder/blob/280c57ec289d09eacad0a29cf97a332e02ecebaa/thunder/dynamo/utils.py#L312-L322

Signed-off-by: Masaki Kozuki <[email protected]>

1 parent d001dbe
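To make the motivation concrete, here is a minimal, hedged sketch of the scenario the message describes: a custom op with no backward definition, applied to an input that requires grad. The op name `mylib::scale` and the function `f` are illustrative, the Thunder-side registration of the op is omitted, and this assumes `thunderfx` forwards keyword arguments to Thunder as options:

import torch
from thunder.dynamo import thunderfx

# A custom op with a fake (meta) kernel but no autograd/backward definition.
@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

def f(x: torch.Tensor) -> torch.Tensor:
    return scale(x, 2.0)

x = torch.randn(4, requires_grad=True)

# Previously the splitter saw requires_grad on the input proxies, traced the
# op under value_and_grad, and could split the graph because the backward is
# missing. With this commit, the option below reaches the splitter's
# support check.
compiled = thunderfx(f, disable_torch_autograd=True)
out = compiled(x)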

3 files changed, +21 -9 lines changed

thunder/dynamo/compiler.py

Lines changed: 1 addition & 0 deletions

@@ -150,6 +150,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
            partial(jit, **thunder_options),
            self._torch_compile,
            sample_args,
+            thunder_options,
        )
        self.subgraph_infos.append(subgraph_info)
        return split_module
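In practice, this means options given when constructing the backend now reach the splitter. A usage sketch, assuming `ThunderCompiler` accepts thunder options as keyword arguments (as the `**thunder_options` handling in this file suggests); the model is illustrative:

import torch
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler(disable_torch_autograd=True)
model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend=backend)
out = compiled(torch.randn(2, 8, requires_grad=True))
# backend.subgraph_infos records, per subgraph, any split reasons collected
# by _splitter, which now also sees disable_torch_autograd.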

thunder/dynamo/splitter.py

Lines changed: 2 additions & 1 deletion

@@ -30,6 +30,7 @@ def _splitter(
    thunder_jit: Callable,
    torch_inductor: Callable,
    _unused_sample_args: list[torch.SymInt, torch.Tensor],
+    thunder_options: dict[str, Any],
) -> tuple[torch.fx.GraphModule, SubgraphInfo]:
    """
    This method will split graph into multiple graph modules based on thunder supported operations.

@@ -119,7 +120,7 @@ def callback(node) -> int:
            )
            split_reasons.append(split_reason)
        else:
-            is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+            is_thunder_supported, split_reason = is_node_supported_by_thunder(node, thunder_options)
        if split_reason is not None:
            split_reasons.append(split_reason)

thunder/dynamo/utils.py

Lines changed: 18 additions & 8 deletions

@@ -243,7 +243,9 @@ def make_input_proxy(arg_node):
    return proxy_args, proxy_kwargs


-def try_execute_thunder_symbol(thunder_symbol: Symbol, node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+def try_execute_thunder_symbol(
+    thunder_symbol: Symbol, node: torch.fx.Node, thunder_options: dict[str, Any]
+) -> tuple[bool, SplitReason | None]:
    """
    Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments.

@@ -287,6 +289,7 @@ def get_requires_grad(arg_node):

    args, _ = tree_flatten((node.args, node.kwargs))
    requires_grad = any(map(get_requires_grad, args))
+    disable_torch_autograd: bool | None = thunder_options.get("disable_torch_autograd", None)

    @compile_data_and_stats(cd, cs)
    @thunder._with_cache_info_ctx

@@ -309,7 +312,12 @@ def _run_with_cache_info():
                exception=str(e),
            )

-    function_to_run = value_and_grad(thunder_symbol) if requires_grad else thunder_symbol
+    function_to_run = thunder_symbol
+    function_to_run = (
+        value_and_grad(thunder_symbol)
+        if requires_grad and (disable_torch_autograd is None or not disable_torch_autograd)
+        else thunder_symbol
+    )
    # We need to be under trace context to generate proxies.
    with thunder.core.trace.tracectx(TraceCtx()):
        try:

@@ -351,7 +359,7 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.fx.Node]:
    return nodes_in_unsupported_ctx_regions


-def is_graphmodule_supported_by_thunder(gm):
+def is_graphmodule_supported_by_thunder(gm, thunder_options: dict[str, Any]):
    nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
    for node in gm.graph.nodes:
        if node.op in (

@@ -367,13 +375,15 @@ def is_graphmodule_supported_by_thunder(gm):
            )
            return False, split_reason

-        is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+        is_thunder_supported, split_reason = is_node_supported_by_thunder(node, thunder_options)
        if not is_thunder_supported:
            return False, split_reason
    return True, None


-def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+def is_node_supported_by_thunder(
+    node: torch.fx.Node, thunder_options: dict[str, Any]
+) -> tuple[bool, SplitReason | None]:
    """
    Determine whether thunder can execute the operation described by this node.
    """

@@ -425,7 +435,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
        for arg_node in node.args:
            if arg_node.op == "get_attr":
                called_module = getattr(m, arg_node.target)
-                is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module)
+                is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module, thunder_options)
                if not is_module_supported:
                    return is_module_supported, split_reason
        return True, None

@@ -438,7 +448,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
    # We try to proxify the arguments and call these operations on them to see if they are supported.
    if target in _torch_to_thunder_function_map or inspect.isbuiltin(target):
        thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target)
-        did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node)
+        did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node, thunder_options)
        return did_run, opt_split_reason

    # There are few operations which are registered only as method in `torchctx` and hence they don't exist

@@ -457,7 +467,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
        )
        # NOTE: `get_method` may throw if relevant method is not found, so we have guarded it with `has_method`.
        method = torchctx.get_method(node.target, args, kwargs)
-        did_run, opt_split_reason = try_execute_thunder_symbol(method, node)
+        did_run, opt_split_reason = try_execute_thunder_symbol(method, node, thunder_options)
        return did_run, opt_split_reason

    # checks einops operators
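Taken together, the utils.py change reduces to one gating rule in try_execute_thunder_symbol: wrap the symbol with value_and_grad only if some input requires grad and the user has not opted out. A standalone sketch of that rule, with the hypothetical `wrap_with_grad` standing in for Thunder's `value_and_grad`:

from typing import Any
from collections.abc import Callable

def pick_function_to_run(
    thunder_symbol: Callable,
    requires_grad: bool,
    thunder_options: dict[str, Any],
    wrap_with_grad: Callable[[Callable], Callable],
) -> Callable:
    # None means the option is unset, which preserves the old behavior.
    disable_torch_autograd: bool | None = thunder_options.get("disable_torch_autograd", None)
    # `x is None or not x` is equivalent to `not x` here (since `not None`
    # is True); the long form mirrors the diff above.
    if requires_grad and (disable_torch_autograd is None or not disable_torch_autograd):
        return wrap_with_grad(thunder_symbol)
    return thunder_symbol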
