@@ -24,6 +24,7 @@ class CacheMixin:
2424
2525    Supported caching techniques: 
2626        - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) 
27+         - [FasterCache](https://huggingface.co/papers/2410.19355) 
2728    """ 
2829
2930    _cache_config  =  None 
@@ -59,25 +60,43 @@ def enable_cache(self, config) -> None:
5960        ``` 
6061        """ 
6162
62-         from  ..hooks  import  PyramidAttentionBroadcastConfig , apply_pyramid_attention_broadcast 
63+         from  ..hooks  import  (
64+             FasterCacheConfig ,
65+             PyramidAttentionBroadcastConfig ,
66+             apply_faster_cache ,
67+             apply_pyramid_attention_broadcast ,
68+         )
69+ 
70+         if  self .is_cache_enabled :
71+             raise  ValueError (
72+                 f"Caching has already been enabled with { type (self ._cache_config )}  
73+             )
6374
6475        if  isinstance (config , PyramidAttentionBroadcastConfig ):
6576            apply_pyramid_attention_broadcast (self , config )
77+         elif  isinstance (config , FasterCacheConfig ):
78+             apply_faster_cache (self , config )
6679        else :
6780            raise  ValueError (f"Cache config { type (config )}  )
6881
6982        self ._cache_config  =  config 
7083
7184    def  disable_cache (self ) ->  None :
72-         from  ..hooks  import  HookRegistry , PyramidAttentionBroadcastConfig 
85+         from  ..hooks  import  FasterCacheConfig , HookRegistry , PyramidAttentionBroadcastConfig 
86+         from  ..hooks .faster_cache  import  _FASTER_CACHE_BLOCK_HOOK , _FASTER_CACHE_DENOISER_HOOK 
87+         from  ..hooks .pyramid_attention_broadcast  import  _PYRAMID_ATTENTION_BROADCAST_HOOK 
7388
7489        if  self ._cache_config  is  None :
7590            logger .warning ("Caching techniques have not been enabled, so there's nothing to disable." )
7691            return 
7792
7893        if  isinstance (self ._cache_config , PyramidAttentionBroadcastConfig ):
7994            registry  =  HookRegistry .check_if_exists_or_initialize (self )
80-             registry .remove_hook ("pyramid_attention_broadcast" , recurse = True )
95+             registry .remove_hook (_PYRAMID_ATTENTION_BROADCAST_HOOK , recurse = True )
96+         elif  isinstance (self ._cache_config , FasterCacheConfig ):
97+             registry  =  HookRegistry .check_if_exists_or_initialize (self )
98+             registry .remove_hook (_FASTER_CACHE_DENOISER_HOOK , recurse = True )
99+             registry .remove_hook (_FASTER_CACHE_BLOCK_HOOK , recurse = True )
81100        else :
82101            raise  ValueError (f"Cache config { type (self ._cache_config )}  )
83102
0 commit comments