1 change: 1 addition & 0 deletions thunder/dynamo/compiler.py
@@ -150,6 +150,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]):
partial(jit, **thunder_options),
self._torch_compile,
sample_args,
+ thunder_options,
)
self.subgraph_infos.append(subgraph_info)
return split_module
3 changes: 2 additions & 1 deletion thunder/dynamo/splitter.py
@@ -30,6 +30,7 @@ def _splitter(
thunder_jit: Callable,
torch_inductor: Callable,
_unused_sample_args: list[torch.SymInt, torch.Tensor],
+ thunder_options: dict[str, Any],
) -> tuple[torch.fx.GraphModule, SubgraphInfo]:
"""
This method splits the graph into multiple graph modules based on Thunder-supported operations.
@@ -119,7 +120,7 @@ def callback(node) -> int:
)
split_reasons.append(split_reason)
else:
- is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+ is_thunder_supported, split_reason = is_node_supported_by_thunder(node, thunder_options)
if split_reason is not None:
split_reasons.append(split_reason)

26 changes: 18 additions & 8 deletions thunder/dynamo/utils.py
@@ -243,7 +243,9 @@ def make_input_proxy(arg_node):
return proxy_args, proxy_kwargs


- def try_execute_thunder_symbol(thunder_symbol: Symbol, node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+ def try_execute_thunder_symbol(
+     thunder_symbol: Symbol, node: torch.fx.Node, thunder_options: dict[str, Any]
+ ) -> tuple[bool, SplitReason | None]:
"""
Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments.

Expand Down Expand Up @@ -287,6 +289,7 @@ def get_requires_grad(arg_node):

args, _ = tree_flatten((node.args, node.kwargs))
requires_grad = any(map(get_requires_grad, args))
+ disable_torch_autograd: bool | None = thunder_options.get("disable_torch_autograd", None)

@compile_data_and_stats(cd, cs)
@thunder._with_cache_info_ctx
@@ -309,7 +312,12 @@ def _run_with_cache_info():
exception=str(e),
)

- function_to_run = value_and_grad(thunder_symbol) if requires_grad else thunder_symbol
+ function_to_run = thunder_symbol
shino16 (Collaborator), Oct 7, 2025:
Is this line redundant?

Author:
ah yes, great catch.
+ function_to_run = (
+     value_and_grad(thunder_symbol)
+     if requires_grad and (disable_torch_autograd is None or not disable_torch_autograd)
+     else thunder_symbol
+ )
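For reference, since `not None` is True, the new guard `disable_torch_autograd is None or not disable_torch_autograd` is equivalent to `not disable_torch_autograd`. A minimal standalone sketch of the predicate (the helper name is hypothetical; only disable_torch_autograd comes from this diff):

def should_apply_value_and_grad(requires_grad: bool, disable_torch_autograd: bool | None) -> bool:
    # value_and_grad wraps the symbol only when gradients are required and the
    # user has not explicitly disabled torch autograd; an unset option (None)
    # behaves like False here.
    return requires_grad and not disable_torch_autograd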
# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):
try:
@@ -351,7 +359,7 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.fx.Node]:
return nodes_in_unsupported_ctx_regions


- def is_graphmodule_supported_by_thunder(gm):
+ def is_graphmodule_supported_by_thunder(gm, thunder_options: dict[str, Any]):
nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
for node in gm.graph.nodes:
if node.op in (
@@ -367,13 +375,15 @@ def is_graphmodule_supported_by_thunder(gm):
)
return False, split_reason

- is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+ is_thunder_supported, split_reason = is_node_supported_by_thunder(node, thunder_options)
if not is_thunder_supported:
return False, split_reason
return True, None


- def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+ def is_node_supported_by_thunder(
+     node: torch.fx.Node, thunder_options: dict[str, Any]
+ ) -> tuple[bool, SplitReason | None]:
"""
Determine whether thunder can execute the operation described by this node.
"""
@@ -425,7 +435,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
for arg_node in node.args:
if arg_node.op == "get_attr":
called_module = getattr(m, arg_node.target)
- is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module)
+ is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module, thunder_options)
if not is_module_supported:
return is_module_supported, split_reason
return True, None
@@ -438,7 +448,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
# We try to proxify the arguments and call these operations on them to see if they are supported.
if target in _torch_to_thunder_function_map or inspect.isbuiltin(target):
thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target)
- did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node)
+ did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node, thunder_options)
return did_run, opt_split_reason

# There are a few operations that are registered only as methods in `torchctx` and hence they don't exist
@@ -457,7 +467,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
)
# NOTE: `get_method` may throw if the relevant method is not found, so we have guarded it with `has_method`.
method = torchctx.get_method(node.target, args, kwargs)
- did_run, opt_split_reason = try_execute_thunder_symbol(method, node)
+ did_run, opt_split_reason = try_execute_thunder_symbol(method, node, thunder_options)
return did_run, opt_split_reason

# checks einops operators
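A minimal sketch of how the forwarded option reaches the splitter end to end, assuming the thunder.dynamo.ThunderCompiler entry point (the toy model and shapes are illustrative):

import torch
from thunder.dynamo import ThunderCompiler

# Keyword arguments passed here land in thunder_options; with this change the
# splitter also consults them, so disable_torch_autograd=True keeps
# value_and_grad from being applied while probing node support.
backend = ThunderCompiler(disable_torch_autograd=True)

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend=backend)
out = compiled(torch.randn(2, 8, requires_grad=True))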