 # limitations under the License.

 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Callable, List, Optional

 import torch

 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
-from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
-from ._helpers import TransformerBlockRegistry
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
+from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
 from .hooks import HookRegistry, ModelHook


@@ -44,9 +44,50 @@ class LayerSkipConfig:

     indices: List[int]
     fqn: str = "auto"
+    skip_attention: bool = True
+    skip_attention_scores: bool = False
+    skip_ff: bool = True


-class LayerSkipHook(ModelHook):
+class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        if func is torch.nn.functional.scaled_dot_product_attention:
+            value = kwargs.get("value", None)
+            if value is None:
+                value = args[2]
+            return value
+        return func(*args, **kwargs)
+
+
+class AttentionProcessorSkipHook(ModelHook):
+    def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False):
+        self.skip_processor_output_fn = skip_processor_output_fn
+        self.skip_attention_scores = skip_attention_scores
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        if self.skip_attention_scores:
+            with AttentionScoreSkipFunctionMode():
+                return self.fn_ref.original_forward(*args, **kwargs)
+        else:
+            return self.skip_processor_output_fn(module, *args, **kwargs)
+
+
+class FeedForwardSkipHook(ModelHook):
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        output = kwargs.get("hidden_states", None)
+        if output is None:
+            output = kwargs.get("x", None)
+        if output is None and len(args) > 0:
+            output = args[0]
+        return output
+
+
+class TransformerBlockSkipHook(ModelHook):
     def initialize_hook(self, module):
         self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
         return module
@@ -81,6 +122,9 @@ def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
 def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
     name = name or _LAYER_SKIP_HOOK

+    if config.skip_attention and config.skip_attention_scores:
+        raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
+
     if config.fqn == "auto":
         for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
             if hasattr(module, identifier):
@@ -101,10 +145,38 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam
     if len(config.indices) == 0:
         raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")

+    blocks_found = False
     for i, block in enumerate(transformer_blocks):
         if i not in config.indices:
             continue
-        logger.debug(f"Apply LayerSkipHook to '{config.fqn}.{i}'")
-        registry = HookRegistry.check_if_exists_or_initialize(block)
-        hook = LayerSkipHook()
-        registry.register_hook(hook, name)
+        blocks_found = True
+        if config.skip_attention and config.skip_ff:
+            logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
+            registry = HookRegistry.check_if_exists_or_initialize(block)
+            hook = TransformerBlockSkipHook()
+            registry.register_hook(hook, name)
+        elif config.skip_attention or config.skip_attention_scores:
+            for submodule_name, submodule in block.named_modules():
+                if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
+                    logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+                    output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
+                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
+                    hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores)
+                    registry.register_hook(hook, name)
+        elif config.skip_ff:
+            for submodule_name, submodule in block.named_modules():
+                if isinstance(submodule, _FEEDFORWARD_CLASSES):
+                    logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+                    registry = HookRegistry.check_if_exists_or_initialize(submodule)
+                    hook = FeedForwardSkipHook()
+                    registry.register_hook(hook, name)
+        else:
+            raise ValueError(
+                "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True."
+            )
+
+    if not blocks_found:
+        raise ValueError(
+            f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+            f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+        )
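
For reference, a minimal usage sketch of the new config flags follows. The `transformer` variable and the import path are assumptions for illustration, not part of this diff; any diffusers transformer exposing one of the known block identifiers should behave the same way.

    from diffusers.hooks import LayerSkipConfig, apply_layer_skip  # import path assumed

    # Skip only the feed-forward sublayers of blocks 2, 5, and 8; attention is left untouched.
    config = LayerSkipConfig(indices=[2, 5, 8], fqn="auto", skip_attention=False, skip_ff=True)
    apply_layer_skip(transformer, config)  # `transformer` is an already-loaded transformer model

Note that combining `skip_attention=True` (the default) with `skip_attention_scores=True` is rejected by the validation added above, so `skip_attention` must be set to False when only the attention scores should be skipped.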
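The `skip_attention_scores` path relies on `torch.overrides.TorchFunctionMode` to intercept `torch.nn.functional.scaled_dot_product_attention` and return the value tensor unchanged, so the attention scores are never computed. A self-contained sketch of that interception pattern, independent of the hook machinery (class and variable names here are hypothetical):

    import torch
    import torch.nn.functional as F

    class SDPAValuePassthrough(torch.overrides.TorchFunctionMode):
        # Standalone equivalent of AttentionScoreSkipFunctionMode: when SDPA is
        # called inside this context, return the value tensor instead of attending.
        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func is F.scaled_dot_product_attention:
                return kwargs.get("value", args[2] if len(args) > 2 else None)
            return func(*args, **kwargs)

    query = key = value = torch.randn(1, 8, 16, 64)
    with SDPAValuePassthrough():
        out = F.scaled_dot_product_attention(query, key, value)
    assert torch.equal(out, value)  # the scores were skipped; the output is just `value`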