@@ -111,11 +111,13 @@ def init_hook(self, module):
111111 module = hook .init_hook (module )
112112 return module
113113
114+ @torch .compiler .disable
114115 def pre_forward (self , module , * args , ** kwargs ):
115116 for hook in self .hooks :
116117 args , kwargs = hook .pre_forward (module , * args , ** kwargs )
117118 return args , kwargs
118119
120+ @torch .compiler .disable
119121 def post_forward (self , module , output ):
120122 for hook in self .hooks :
121123 output = hook .post_forward (module , output )
@@ -325,6 +327,7 @@ def init_hook(self, module):
325327
326328 return module
327329
330+ @torch .compiler .disable
328331 def pre_forward (self , module , * args , ** kwargs ):
329332 if self .io_same_device :
330333 self .input_device = find_device ([args , kwargs ])
@@ -370,6 +373,7 @@ def pre_forward(self, module, *args, **kwargs):
370373 kwargs , self .execution_device , skip_keys = self .skip_keys
371374 )
372375
376+ @torch .compiler .disable
373377 def post_forward (self , module , output ):
374378 if self .offload :
375379 for name , _ in named_module_tensors (
@@ -713,6 +717,7 @@ def __init__(
713717 def init_hook (self , module ):
714718 return module .to ("cpu" )
715719
720+ @torch .compiler .disable
716721 def pre_forward (self , module , * args , ** kwargs ):
717722 if self .prev_module_hook is not None and isinstance (self .prev_module_hook , UserCpuOffloadHook ):
718723 prev_module = self .prev_module_hook .model
@@ -767,10 +772,12 @@ def init_hook(self, module: torch.nn.Module):
767772 module .to (dtype = self .storage_dtype , non_blocking = self .non_blocking )
768773 return module
769774
775+ @torch .compiler .disable
770776 def pre_forward (self , module : torch .nn .Module , * args , ** kwargs ):
771777 module .to (dtype = self .compute_dtype , non_blocking = self .non_blocking )
772778 return args , kwargs
773779
780+ @torch .compiler .disable
774781 def post_forward (self , module : torch .nn .Module , output ):
775782 module .to (dtype = self .storage_dtype , non_blocking = self .non_blocking )
776783 return output
0 commit comments