@@ -66,10 +66,12 @@ def enable_cache(self, config) -> None:
6666 from ..hooks import (
6767 FasterCacheConfig ,
6868 FirstBlockCacheConfig ,
69+ MagCacheConfig ,
6970 PyramidAttentionBroadcastConfig ,
7071 TaylorSeerCacheConfig ,
7172 apply_faster_cache ,
7273 apply_first_block_cache ,
74+ apply_mag_cache ,
7375 apply_pyramid_attention_broadcast ,
7476 apply_taylorseer_cache ,
7577 )
@@ -83,6 +85,8 @@ def enable_cache(self, config) -> None:
8385 apply_faster_cache (self , config )
8486 elif isinstance (config , FirstBlockCacheConfig ):
8587 apply_first_block_cache (self , config )
88+ elif isinstance (config , MagCacheConfig ):
89+ apply_mag_cache (self , config )
8690 elif isinstance (config , PyramidAttentionBroadcastConfig ):
8791 apply_pyramid_attention_broadcast (self , config )
8892 elif isinstance (config , TaylorSeerCacheConfig ):
@@ -97,11 +101,13 @@ def disable_cache(self) -> None:
97101 FasterCacheConfig ,
98102 FirstBlockCacheConfig ,
99103 HookRegistry ,
104+ MagCacheConfig ,
100105 PyramidAttentionBroadcastConfig ,
101106 TaylorSeerCacheConfig ,
102107 )
103108 from ..hooks .faster_cache import _FASTER_CACHE_BLOCK_HOOK , _FASTER_CACHE_DENOISER_HOOK
104109 from ..hooks .first_block_cache import _FBC_BLOCK_HOOK , _FBC_LEADER_BLOCK_HOOK
110+ from ..hooks .mag_cache import _MAG_CACHE_LEADER_BLOCK_HOOK
105111 from ..hooks .pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
106112 from ..hooks .taylorseer_cache import _TAYLORSEER_CACHE_HOOK
107113
@@ -116,6 +122,8 @@ def disable_cache(self) -> None:
116122 elif isinstance (self ._cache_config , FirstBlockCacheConfig ):
117123 registry .remove_hook (_FBC_LEADER_BLOCK_HOOK , recurse = True )
118124 registry .remove_hook (_FBC_BLOCK_HOOK , recurse = True )
125+ elif isinstance (self ._cache_config , MagCacheConfig ):
126+ registry .remove_hook (_MAG_CACHE_LEADER_BLOCK_HOOK , recurse = True )
119127 elif isinstance (self ._cache_config , PyramidAttentionBroadcastConfig ):
120128 registry .remove_hook (_PYRAMID_ATTENTION_BROADCAST_HOOK , recurse = True )
121129 elif isinstance (self ._cache_config , TaylorSeerCacheConfig ):
0 commit comments