@@ -111,13 +111,13 @@ def init_hook(self, module):
111111 module = hook .init_hook (module )
112112 return module
113113
114- @torch .compiler .disable ()
114+ @torch .compiler .disable
115115 def pre_forward (self , module , * args , ** kwargs ):
116116 for hook in self .hooks :
117117 args , kwargs = hook .pre_forward (module , * args , ** kwargs )
118118 return args , kwargs
119119
120- @torch .compiler .disable ()
120+ @torch .compiler .disable
121121 def post_forward (self , module , output ):
122122 for hook in self .hooks :
123123 output = hook .post_forward (module , output )
@@ -327,7 +327,7 @@ def init_hook(self, module):
327327
328328 return module
329329
330- @torch .compiler .disable ()
330+ @torch .compiler .disable
331331 def pre_forward (self , module , * args , ** kwargs ):
332332 if self .io_same_device :
333333 self .input_device = find_device ([args , kwargs ])
@@ -373,7 +373,7 @@ def pre_forward(self, module, *args, **kwargs):
373373 kwargs , self .execution_device , skip_keys = self .skip_keys
374374 )
375375
376- @torch .compiler .disable ()
376+ @torch .compiler .disable
377377 def post_forward (self , module , output ):
378378 if self .offload :
379379 for name , _ in named_module_tensors (
@@ -717,7 +717,7 @@ def __init__(
717717 def init_hook (self , module ):
718718 return module .to ("cpu" )
719719
720- @torch .compiler .disable ()
720+ @torch .compiler .disable
721721 def pre_forward (self , module , * args , ** kwargs ):
722722 if self .prev_module_hook is not None and isinstance (self .prev_module_hook , UserCpuOffloadHook ):
723723 prev_module = self .prev_module_hook .model
@@ -772,12 +772,12 @@ def init_hook(self, module: torch.nn.Module):
772772 module .to (dtype = self .storage_dtype , non_blocking = self .non_blocking )
773773 return module
774774
775- @torch .compiler .disable ()
775+ @torch .compiler .disable
776776 def pre_forward (self , module : torch .nn .Module , * args , ** kwargs ):
777777 module .to (dtype = self .compute_dtype , non_blocking = self .non_blocking )
778778 return args , kwargs
779779
780- @torch .compiler .disable ()
780+ @torch .compiler .disable
781781 def post_forward (self , module : torch .nn .Module , output ):
782782 module .to (dtype = self .storage_dtype , non_blocking = self .non_blocking )
783783 return output
0 commit comments