Skip to content

Commit db8835c

Browse files
authored
Disable hook compile (#3888)
* disable hook compile * fix
1 parent 16b6b3f commit db8835c

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/accelerate/hooks.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,13 @@ def init_hook(self, module):
111111
module = hook.init_hook(module)
112112
return module
113113

114+
@torch.compiler.disable
114115
def pre_forward(self, module, *args, **kwargs):
115116
for hook in self.hooks:
116117
args, kwargs = hook.pre_forward(module, *args, **kwargs)
117118
return args, kwargs
118119

120+
@torch.compiler.disable
119121
def post_forward(self, module, output):
120122
for hook in self.hooks:
121123
output = hook.post_forward(module, output)
@@ -325,6 +327,7 @@ def init_hook(self, module):
325327

326328
return module
327329

330+
@torch.compiler.disable
328331
def pre_forward(self, module, *args, **kwargs):
329332
if self.io_same_device:
330333
self.input_device = find_device([args, kwargs])
@@ -370,6 +373,7 @@ def pre_forward(self, module, *args, **kwargs):
370373
kwargs, self.execution_device, skip_keys=self.skip_keys
371374
)
372375

376+
@torch.compiler.disable
373377
def post_forward(self, module, output):
374378
if self.offload:
375379
for name, _ in named_module_tensors(
@@ -713,6 +717,7 @@ def __init__(
713717
def init_hook(self, module):
714718
return module.to("cpu")
715719

720+
@torch.compiler.disable
716721
def pre_forward(self, module, *args, **kwargs):
717722
if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):
718723
prev_module = self.prev_module_hook.model
@@ -767,10 +772,12 @@ def init_hook(self, module: torch.nn.Module):
767772
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
768773
return module
769774

775+
@torch.compiler.disable
770776
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
771777
module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
772778
return args, kwargs
773779

780+
@torch.compiler.disable
774781
def post_forward(self, module: torch.nn.Module, output):
775782
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
776783
return output

0 commit comments

Comments
 (0)