Skip to content

Commit 9971ae6

Browse files
committed
fix
1 parent 95f0063 commit 9971ae6

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/accelerate/hooks.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ def init_hook(self, module):
111111
module = hook.init_hook(module)
112112
return module
113113

114-
@torch.compiler.disable()
114+
@torch.compiler.disable
115115
def pre_forward(self, module, *args, **kwargs):
116116
for hook in self.hooks:
117117
args, kwargs = hook.pre_forward(module, *args, **kwargs)
118118
return args, kwargs
119119

120-
@torch.compiler.disable()
120+
@torch.compiler.disable
121121
def post_forward(self, module, output):
122122
for hook in self.hooks:
123123
output = hook.post_forward(module, output)
@@ -327,7 +327,7 @@ def init_hook(self, module):
327327

328328
return module
329329

330-
@torch.compiler.disable()
330+
@torch.compiler.disable
331331
def pre_forward(self, module, *args, **kwargs):
332332
if self.io_same_device:
333333
self.input_device = find_device([args, kwargs])
@@ -373,7 +373,7 @@ def pre_forward(self, module, *args, **kwargs):
373373
kwargs, self.execution_device, skip_keys=self.skip_keys
374374
)
375375

376-
@torch.compiler.disable()
376+
@torch.compiler.disable
377377
def post_forward(self, module, output):
378378
if self.offload:
379379
for name, _ in named_module_tensors(
@@ -717,7 +717,7 @@ def __init__(
717717
def init_hook(self, module):
718718
return module.to("cpu")
719719

720-
@torch.compiler.disable()
720+
@torch.compiler.disable
721721
def pre_forward(self, module, *args, **kwargs):
722722
if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):
723723
prev_module = self.prev_module_hook.model
@@ -772,12 +772,12 @@ def init_hook(self, module: torch.nn.Module):
772772
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
773773
return module
774774

775-
@torch.compiler.disable()
775+
@torch.compiler.disable
776776
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
777777
module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
778778
return args, kwargs
779779

780-
@torch.compiler.disable()
780+
@torch.compiler.disable
781781
def post_forward(self, module: torch.nn.Module, output):
782782
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
783783
return output

0 commit comments

Comments
 (0)