diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py index 5d6a22e0a9e..8e45141307b 100644 --- a/src/accelerate/hooks.py +++ b/src/accelerate/hooks.py @@ -111,11 +111,13 @@ def init_hook(self, module): module = hook.init_hook(module) return module + @torch.compiler.disable def pre_forward(self, module, *args, **kwargs): for hook in self.hooks: args, kwargs = hook.pre_forward(module, *args, **kwargs) return args, kwargs + @torch.compiler.disable def post_forward(self, module, output): for hook in self.hooks: output = hook.post_forward(module, output) @@ -325,6 +327,7 @@ def init_hook(self, module): return module + @torch.compiler.disable def pre_forward(self, module, *args, **kwargs): if self.io_same_device: self.input_device = find_device([args, kwargs]) @@ -370,6 +373,7 @@ def pre_forward(self, module, *args, **kwargs): kwargs, self.execution_device, skip_keys=self.skip_keys ) + @torch.compiler.disable def post_forward(self, module, output): if self.offload: for name, _ in named_module_tensors( @@ -713,6 +717,7 @@ def __init__( def init_hook(self, module): return module.to("cpu") + @torch.compiler.disable def pre_forward(self, module, *args, **kwargs): if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook): prev_module = self.prev_module_hook.model @@ -767,10 +772,12 @@ def init_hook(self, module: torch.nn.Module): module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) return module + @torch.compiler.disable def pre_forward(self, module: torch.nn.Module, *args, **kwargs): module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking) return args, kwargs + @torch.compiler.disable def post_forward(self, module: torch.nn.Module, output): module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) return output