1313# limitations under the License. 
1414
1515import  functools 
16- from  typing  import  Any , Callable ,  Dict , Tuple 
16+ from  typing  import  Any , Dict , Tuple 
1717
1818import  torch 
1919
@@ -28,6 +28,7 @@ class ModelHook:
2828    def  init_hook (self , module : torch .nn .Module ) ->  torch .nn .Module :
2929        r""" 
3030        Hook that is executed when a model is initialized. 
31+ 
3132        Args: 
3233            module (`torch.nn.Module`): 
3334                The module attached to this hook. 
@@ -37,6 +38,7 @@ def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
3738    def  pre_forward (self , module : torch .nn .Module , * args , ** kwargs ) ->  Tuple [Tuple [Any ], Dict [str , Any ]]:
3839        r""" 
3940        Hook that is executed just before the forward method of the model. 
41+ 
4042        Args: 
4143            module (`torch.nn.Module`): 
4244                The module whose forward pass will be executed just after this event. 
@@ -53,6 +55,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A
5355    def  post_forward (self , module : torch .nn .Module , output : Any ) ->  Any :
5456        r""" 
5557        Hook that is executed just after the forward method of the model. 
58+ 
5659        Args: 
5760            module (`torch.nn.Module`): 
5861                The module whose forward pass has been executed just before this event. 
@@ -66,15 +69,13 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
6669    def  detach_hook (self , module : torch .nn .Module ) ->  torch .nn .Module :
6770        r""" 
6871        Hook that is executed when the hook is detached from a module. 
72+ 
6973        Args: 
7074            module (`torch.nn.Module`): 
7175                The module detached from this hook. 
7276        """ 
7377        return  module 
7478
75-     def  reset_state (self , module : torch .nn .Module ) ->  torch .nn .Module :
76-         return  module 
77- 
7879
7980class  SequentialHook (ModelHook ):
8081    r"""A hook that can contain several hooks and iterates through them at each event.""" 
@@ -102,52 +103,19 @@ def detach_hook(self, module):
102103            module  =  hook .detach_hook (module )
103104        return  module 
104105
105-     def  reset_state (self , module ):
106-         for  hook  in  self .hooks :
107-             module  =  hook .reset_state (module )
108-         return  module 
109- 
110- 
111- class  FasterCacheHook (ModelHook ):
112-     def  __init__ (
113-         self ,
114-         skip_callback : Callable [[torch .nn .Module ], bool ],
115-     ) ->  None :
116-         super ().__init__ ()
117- 
118-         self .skip_callback  =  skip_callback 
119- 
120-         self .cache  =  None 
121-         self ._iteration  =  0 
122- 
123-     def  new_forward (self , module : torch .nn .Module , * args , ** kwargs ) ->  Any :
124-         args , kwargs  =  module ._diffusers_hook .pre_forward (module , * args , ** kwargs )
125- 
126-         if  self .cache  is  not None  and  self .skip_callback (module ):
127-             output  =  self .cache 
128-         else :
129-             output  =  module ._old_forward (* args , ** kwargs )
130- 
131-         return  module ._diffusers_hook .post_forward (module , output )
132- 
133-     def  post_forward (self , module : torch .nn .Module , output : Any ) ->  Any :
134-         self .cache  =  output 
135-         return  output 
136- 
137-     def  reset_state (self , module : torch .nn .Module ) ->  torch .nn .Module :
138-         self .cache  =  None 
139-         self ._iteration  =  0 
140-         return  module 
141- 
142106
143107def  add_hook_to_module (module : torch .nn .Module , hook : ModelHook , append : bool  =  False ):
144108    r""" 
145109    Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove 
146110    this behavior and restore the original `forward` method, use `remove_hook_from_module`. 
111+ 
147112    <Tip warning={true}> 
113+ 
148114    If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks 
149115    together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. 
116+ 
150117    </Tip> 
118+ 
151119    Args: 
152120        module (`torch.nn.Module`): 
153121            The module to attach a hook to. 
@@ -198,6 +166,7 @@ def new_forward(module, *args, **kwargs):
198166def  remove_hook_from_module (module : torch .nn .Module , recurse : bool  =  False ) ->  torch .nn .Module :
199167    """ 
200168    Removes any hook attached to a module via `add_hook_to_module`. 
169+ 
201170    Args: 
202171        module (`torch.nn.Module`): 
203172            The module to remove the hook from. 
0 commit comments