1818import  torch 
1919
2020from  ..utils .logging  import  get_logger 
21+ from  ..utils .torch_utils  import  unwrap_module 
2122
2223
2324logger  =  get_logger (__name__ )  # pylint: disable=invalid-name 
@@ -47,7 +48,7 @@ def get_current_state(self) -> "BaseMarkedState":
4748            self ._state_cache [self ._mark_name ] =  self .__class__ (* self ._init_args , ** self ._init_kwargs )
4849        return  self ._state_cache [self ._mark_name ]
4950
50-     def  mark_batch (self , name : str ) ->  None :
51+     def  mark_state (self , name : str ) ->  None :
5152        self ._mark_name  =  name 
5253
5354    def  reset (self , * args , ** kwargs ) ->  None :
@@ -59,7 +60,7 @@ def reset(self, *args, **kwargs) -> None:
5960    def  __getattribute__ (self , name ):
6061        if  name  in  (
6162            "get_current_state" ,
62-             "mark_batch " ,
63+             "mark_state " ,
6364            "reset" ,
6465            "_init_args" ,
6566            "_init_kwargs" ,
@@ -74,7 +75,7 @@ def __getattribute__(self, name):
7475    def  __setattr__ (self , name , value ):
7576        if  name  in  (
7677            "get_current_state" ,
77-             "mark_batch " ,
78+             "mark_state " ,
7879            "reset" ,
7980            "_init_args" ,
8081            "_init_kwargs" ,
@@ -164,11 +165,11 @@ def reset_state(self, module: torch.nn.Module):
164165        return  module 
165166
166167    def  _mark_state (self , module : torch .nn .Module , name : str ) ->  None :
167-         # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_batch ` on them. 
168+         # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state ` on them. 
168169        for  attr_name  in  dir (self ):
169170            attr  =  getattr (self , attr_name )
170171            if  isinstance (attr , BaseMarkedState ):
171-                 attr .mark_batch (name )
172+                 attr .mark_state (name )
172173        return  module 
173174
174175
@@ -283,9 +284,10 @@ def reset_stateful_hooks(self, recurse: bool = True) -> None:
283284                hook .reset_state (self ._module_ref )
284285
285286        if  recurse :
286-             for  module_name , module  in  self ._module_ref .named_modules ():
287+             for  module_name , module  in  unwrap_module ( self ._module_ref ) .named_modules ():
287288                if  module_name  ==  "" :
288289                    continue 
290+                 module  =  unwrap_module (module )
289291                if  hasattr (module , "_diffusers_hook" ):
290292                    module ._diffusers_hook .reset_stateful_hooks (recurse = False )
291293
@@ -301,9 +303,10 @@ def _mark_state(self, name: str) -> None:
301303            if  hook ._is_stateful :
302304                hook ._mark_state (self ._module_ref , name )
303305
304-         for  module_name , module  in  self ._module_ref .named_modules ():
306+         for  module_name , module  in  unwrap_module ( self ._module_ref ) .named_modules ():
305307            if  module_name  ==  "" :
306308                continue 
309+             module  =  unwrap_module (module )
307310            if  hasattr (module , "_diffusers_hook" ):
308311                module ._diffusers_hook ._mark_state (name )
309312
0 commit comments