Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions thunder/dynamo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
partial(jit, **thunder_options),
self._torch_compile,
sample_args,
thunder_options,
)
self.subgraph_infos.append(subgraph_info)
return split_module
Expand Down
3 changes: 2 additions & 1 deletion thunder/dynamo/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def _splitter(
thunder_jit: Callable,
torch_inductor: Callable,
_unused_sample_args: list[torch.SymInt, torch.Tensor],
thunder_options: dict[str, Any],
) -> tuple[torch.fx.GraphModule, SubgraphInfo]:
"""
This method will split graph into multiple graph modules based on thunder supported operations.
Expand Down Expand Up @@ -119,7 +120,7 @@ def callback(node) -> int:
)
split_reasons.append(split_reason)
else:
is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
is_thunder_supported, split_reason = is_node_supported_by_thunder(node, thunder_options)
if split_reason is not None:
split_reasons.append(split_reason)

Expand Down
25 changes: 17 additions & 8 deletions thunder/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,9 @@ def make_input_proxy(arg_node):
return proxy_args, proxy_kwargs


def try_execute_thunder_symbol(thunder_symbol: Symbol, node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
def try_execute_thunder_symbol(
thunder_symbol: Symbol, node: torch.fx.Node, thunder_options: dict[str, Any]
) -> tuple[bool, SplitReason | None]:
"""
Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments.

Expand Down Expand Up @@ -287,6 +289,7 @@ def get_requires_grad(arg_node):

args, _ = tree_flatten((node.args, node.kwargs))
requires_grad = any(map(get_requires_grad, args))
disable_torch_autograd: bool | None = thunder_options.get("disable_torch_autograd", None)

@compile_data_and_stats(cd, cs)
@thunder._with_cache_info_ctx
Expand All @@ -309,7 +312,11 @@ def _run_with_cache_info():
exception=str(e),
)

function_to_run = value_and_grad(thunder_symbol) if requires_grad else thunder_symbol
function_to_run = (
value_and_grad(thunder_symbol)
if requires_grad and (disable_torch_autograd is None or not disable_torch_autograd)
else thunder_symbol
)
# We need to be under trace context to generate proxies.
with thunder.core.trace.tracectx(TraceCtx()):
try:
Expand Down Expand Up @@ -351,7 +358,7 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.
return nodes_in_unsupported_ctx_regions


def is_graphmodule_supported_by_thunder(gm):
def is_graphmodule_supported_by_thunder(gm, thunder_options: dict[str, Any]):
nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
for node in gm.graph.nodes:
if node.op in (
Expand All @@ -367,13 +374,15 @@ def is_graphmodule_supported_by_thunder(gm):
)
return False, split_reason

is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
is_thunder_supported, split_reason = is_node_supported_by_thunder(node, thunder_options)
if not is_thunder_supported:
return False, split_reason
return True, None


def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
def is_node_supported_by_thunder(
node: torch.fx.Node, thunder_options: dict[str, Any]
) -> tuple[bool, SplitReason | None]:
"""
Determine whether thunder can execute the operation described by this node.
"""
Expand Down Expand Up @@ -425,7 +434,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
for arg_node in node.args:
if arg_node.op == "get_attr":
called_module = getattr(m, arg_node.target)
is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module)
is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module, thunder_options)
if not is_module_supported:
return is_module_supported, split_reason
return True, None
Expand All @@ -438,7 +447,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
# We try to proxify the arguments and call these operations on them to see if they are supported.
if target in _torch_to_thunder_function_map or inspect.isbuiltin(target):
thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target)
did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node)
did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node, thunder_options)
return did_run, opt_split_reason

# There are few operations which are registered only as method in `torchctx` and hence they don't exist
Expand All @@ -457,7 +466,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
)
# NOTE: `get_method` may throw if relevant method is not found, so we have guarded it with `has_method`.
method = torchctx.get_method(node.target, args, kwargs)
did_run, opt_split_reason = try_execute_thunder_symbol(method, node)
did_run, opt_split_reason = try_execute_thunder_symbol(method, node, thunder_options)
return did_run, opt_split_reason

# checks einops operators
Expand Down
Loading