Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/accelerate/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,13 @@ def init_hook(self, module):
module = hook.init_hook(module)
return module

@torch.compiler.disable
def pre_forward(self, module, *args, **kwargs):
for hook in self.hooks:
args, kwargs = hook.pre_forward(module, *args, **kwargs)
return args, kwargs

@torch.compiler.disable
def post_forward(self, module, output):
    # Apply each child hook's post_forward in registration order, feeding
    # the result of one hook in as the ``output`` of the next.
    # NOTE(review): the tail of this method (presumably ``return output``)
    # is outside this visible fragment -- confirm against the full file.
    for hook in self.hooks:
        output = hook.post_forward(module, output)
Expand Down Expand Up @@ -325,6 +327,7 @@ def init_hook(self, module):

return module

@torch.compiler.disable
def pre_forward(self, module, *args, **kwargs):
    # When inputs and outputs must live on the same device, record the
    # device of the incoming args/kwargs; presumably post_forward moves
    # the output back to it -- that half is outside this visible fragment.
    if self.io_same_device:
        self.input_device = find_device([args, kwargs])
Expand Down Expand Up @@ -370,6 +373,7 @@ def pre_forward(self, module, *args, **kwargs):
kwargs, self.execution_device, skip_keys=self.skip_keys
)

@torch.compiler.disable
def post_forward(self, module, output):
    # When offloading is enabled, walk the module's named tensors.
    # NOTE(review): the call below is cut off in this fragment -- its
    # arguments and the loop body are not visible here.
    if self.offload:
        for name, _ in named_module_tensors(
Expand Down Expand Up @@ -713,6 +717,7 @@ def __init__(
def init_hook(self, module):
    """Move the module to CPU so it starts in the offloaded state."""
    return module.to("cpu")

@torch.compiler.disable
def pre_forward(self, module, *args, **kwargs):
    # If a previous module registered a UserCpuOffloadHook, grab that
    # module -- presumably so it can be offloaded before this one runs;
    # the actual offload call is outside this visible fragment.
    if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):
        prev_module = self.prev_module_hook.model
Expand Down Expand Up @@ -767,10 +772,12 @@ def init_hook(self, module: torch.nn.Module):
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
return module

@torch.compiler.disable
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
return args, kwargs

@torch.compiler.disable
def post_forward(self, module: torch.nn.Module, output):
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
return output
Loading