diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py
index 2ee582bba0..b27f2590e4 100644
--- a/thunder/dynamo/compiler.py
+++ b/thunder/dynamo/compiler.py
@@ -150,6 +150,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
             partial(jit, **thunder_options),
             self._torch_compile,
             sample_args,
+            thunder_options,
         )
         self.subgraph_infos.append(subgraph_info)
         return split_module
diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index 7949937ca5..b68fcfff05 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -30,6 +30,7 @@ def _splitter(
     thunder_jit: Callable,
     torch_inductor: Callable,
     _unused_sample_args: list[torch.SymInt, torch.Tensor],
+    thunder_options: dict[str, Any],
 ) -> tuple[torch.fx.GraphModule, SubgraphInfo]:
     """
     This method will split graph into multiple graph modules based on thunder supported operations.
@@ -119,7 +120,7 @@ def callback(node) -> int:
             )
             split_reasons.append(split_reason)
         else:
-            is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+            is_thunder_supported, split_reason = is_node_supported_by_thunder(node, thunder_options)
             if split_reason is not None:
                 split_reasons.append(split_reason)
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 3c58bc2ded..b9bb8f9c16 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -243,7 +243,9 @@ def make_input_proxy(arg_node):
     return proxy_args, proxy_kwargs
 
 
-def try_execute_thunder_symbol(thunder_symbol: Symbol, node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+def try_execute_thunder_symbol(
+    thunder_symbol: Symbol, node: torch.fx.Node, thunder_options: dict[str, Any]
+) -> tuple[bool, SplitReason | None]:
     """
     Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments.
 
@@ -287,6 +289,7 @@ def get_requires_grad(arg_node):
 
     args, _ = tree_flatten((node.args, node.kwargs))
     requires_grad = any(map(get_requires_grad, args))
+    disable_torch_autograd: bool | None = thunder_options.get("disable_torch_autograd", None)
 
     @compile_data_and_stats(cd, cs)
     @thunder._with_cache_info_ctx
@@ -309,7 +312,11 @@ def _run_with_cache_info():
                 exception=str(e),
             )
 
-        function_to_run = value_and_grad(thunder_symbol) if requires_grad else thunder_symbol
+        function_to_run = (
+            value_and_grad(thunder_symbol)
+            if requires_grad and (disable_torch_autograd is None or not disable_torch_autograd)
+            else thunder_symbol
+        )
         # We need to be under trace context to generate proxies.
         with thunder.core.trace.tracectx(TraceCtx()):
             try:
@@ -351,7 +358,7 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.
     return nodes_in_unsupported_ctx_regions
 
 
-def is_graphmodule_supported_by_thunder(gm):
+def is_graphmodule_supported_by_thunder(gm, thunder_options: dict[str, Any]):
     nodes_in_unsupported_ctx_regions = get_nodes_in_unsupported_ctx_regions(gm)
     for node in gm.graph.nodes:
         if node.op in (
@@ -367,13 +374,15 @@ def is_graphmodule_supported_by_thunder(gm):
             )
             return False, split_reason
 
-        is_thunder_supported, split_reason = is_node_supported_by_thunder(node)
+        is_thunder_supported, split_reason = is_node_supported_by_thunder(node, thunder_options)
         if not is_thunder_supported:
             return False, split_reason
 
     return True, None
 
 
-def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason | None]:
+def is_node_supported_by_thunder(
+    node: torch.fx.Node, thunder_options: dict[str, Any]
+) -> tuple[bool, SplitReason | None]:
     """
     Determine whether thunder can execute the operation described by this node.
     """
@@ -425,7 +434,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
         for arg_node in node.args:
             if arg_node.op == "get_attr":
                 called_module = getattr(m, arg_node.target)
-                is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module)
+                is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module, thunder_options)
                 if not is_module_supported:
                     return is_module_supported, split_reason
         return True, None
@@ -438,7 +447,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
     # We try to proxify the arguments and call these operations on them to see if they are supported.
     if target in _torch_to_thunder_function_map or inspect.isbuiltin(target):
         thunder_symbol_or_builtin = _torch_to_thunder_function_map.get(target, target)
-        did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node)
+        did_run, opt_split_reason = try_execute_thunder_symbol(thunder_symbol_or_builtin, node, thunder_options)
         return did_run, opt_split_reason
 
     # There are few operations which are registered only as method in `torchctx` and hence they don't exist
@@ -457,7 +466,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason
         )
         # NOTE: `get_method` may throw if relevant method is not found, so we have guarded it with `has_method`.
         method = torchctx.get_method(node.target, args, kwargs)
-        did_run, opt_split_reason = try_execute_thunder_symbol(method, node)
+        did_run, opt_split_reason = try_execute_thunder_symbol(method, node, thunder_options)
         return did_run, opt_split_reason
 
     # checks einops operators
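Taken together, these hunks thread the user-provided thunder_options from the Dynamo backend down into the splitter's node-support checks, so that when disable_torch_autograd is set, try_execute_thunder_symbol probes the symbol directly instead of wrapping it in value_and_grad. The following is a minimal usage sketch, not part of the patch; it assumes ThunderCompiler forwards its keyword arguments as thunder_options, as the compiler.py hunk above suggests.

    # Hypothetical usage sketch; ThunderCompiler kwargs are assumed to become thunder_options.
    import torch
    from thunder.dynamo import ThunderCompiler

    # `disable_torch_autograd=True` is the option read via
    # `thunder_options.get("disable_torch_autograd", None)` in the utils.py hunk above;
    # with it set, the support check skips the `value_and_grad` wrapping even for
    # inputs that require grad.
    backend = ThunderCompiler(disable_torch_autograd=True)

    model = torch.nn.Linear(4, 4)
    compiled = torch.compile(model, backend=backend)
    out = compiled(torch.randn(2, 4, requires_grad=True))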