diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index 2ee582bba0..6190eef99f 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -128,7 +128,6 @@ def __init__(self, **thunder_options): "thunderfx_disable_split_autograd", _DEFAULT_THUNDERFX_DISABLE_SPLIT_AUTOGRAD ) self.thunder_options = thunder_options - self._torch_compile = torch.compile def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]): from thunder import jit @@ -148,7 +147,7 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor split_module, subgraph_info = _splitter( gm, partial(jit, **thunder_options), - self._torch_compile, + torch.compile, sample_args, ) self.subgraph_infos.append(subgraph_info)