     get_tensor_model_parallel_world_size,
 )
 from megatron.core.process_groups_config import ProcessGroupCollection
-from megatron.core.tensor_parallel.mappings import all_gather_last_dim_from_tensor_parallel_region
+from megatron.core.tensor_parallel.mappings import (
+    all_gather_last_dim_from_tensor_parallel_region,
+)
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
@@ -168,14 +170,49 @@ def __call__(
     ) -> LinearLayer: ...
 
 
+class CoreAttention(Protocol):
+    """Protocol for core_attention modules."""
+
+    def forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        attention_mask: Optional[Tensor],
+        /,
+        *,
+        attn_mask_type: AttnMaskType,
+        attention_bias: Optional[Tensor],
+        packed_seq_params: Optional[PackedSeqParams],
+    ) -> Tensor:
+        """Applies dot product attention."""
+        ...
+
+
+class CoreAttentionBuilder(Protocol):
+    """Protocol for building core_attention layers."""
+
+    def __call__(
+        self,
+        *,
+        config: TransformerConfig,
+        layer_number: int,
+        attn_mask_type: AttnMaskType,
+        attention_type: str,
+        cp_comm_type: Optional[str],
+        softmax_scale: Optional[float],
+        pg_collection: Optional[ProcessGroupCollection],
+    ) -> CoreAttention: ...
+
+
 @dataclass
 class SelfAttentionSubmodules:
     """
     Configuration class for specifying the submodules of a self-attention.
     """
 
     linear_qkv: LinearQkvBuilder
-    core_attention: Union[ModuleSpec, type] = None
+    core_attention: CoreAttentionBuilder
     linear_proj: Union[ModuleSpec, type] = None
     q_layernorm: Union[ModuleSpec, type] = None
     k_layernorm: Union[ModuleSpec, type] = None
@@ -189,7 +226,7 @@ class CrossAttentionSubmodules:
 
     linear_q: LinearLayerBuilder
     linear_kv: LinearLayerBuilder
-    core_attention: Union[ModuleSpec, type] = None
+    core_attention: CoreAttentionBuilder
     linear_proj: Union[ModuleSpec, type] = None
 
 
@@ -273,8 +310,7 @@ def __init__(
             tmp_config.num_query_groups = world_size
         else:
             tmp_config = self.config
-        self.core_attention = build_module(
-            submodules.core_attention,
+        self.core_attention = submodules.core_attention(
             config=tmp_config,
             layer_number=self.layer_number,
             attn_mask_type=self.attn_mask_type,
@@ -342,7 +378,7 @@ def custom_forward(*inputs):
             attention_mask = inputs[3]
             attn_mask_type = inputs[5]
             attn_mask_type = AttnMaskType(attn_mask_type.item())
-            output_ = self.core_attention(
+            output_ = apply_module(self.core_attention)(
                 query,
                 key,
                 value,
@@ -381,7 +417,9 @@ def _get_pp_layer_offset_for_inference(self):
         ), "Virtual pipeline parallelism is not supported for inference"
 
         # Import here to avoid circular imports
-        from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
+        from megatron.core.transformer.transformer_layer import (
+            get_transformer_layer_offset,
+        )
 
         return get_transformer_layer_offset(
             self.config, vp_stage=None, pp_rank=get_pg_rank(self.pg_collection.pp)
@@ -400,7 +438,7 @@ def _adjust_key_value_for_inference(
         sequence_len_offset: Optional[int] = None,
         *,
         inference_params: Optional[BaseInferenceContext] = None,
-    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, AttnMaskType, Tensor]:
         """
         Saves the generated key and value tensors to the end of the buffers in inference_context.
         Returns the full size keys and values from the provided inference_context, as well as
@@ -1017,7 +1055,7 @@ def forward(
         else:
             if inference_context is None or inference_context.is_static_batching():
                 # Static batching attention kernel.
-                core_attn_out = self.core_attention(
+                core_attn_out = apply_module(self.core_attention)(
                     query,
                     key,
                     value,
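For context on how the new protocols are meant to be consumed: the hunks above replace the ModuleSpec-based core_attention field with a CoreAttentionBuilder callable that Attention.__init__ invokes directly. Below is a minimal sketch, not part of this diff, of a builder satisfying the declared signatures; SimpleCoreAttention and build_simple_core_attention are hypothetical names, and the forward body is a layout-naive placeholder rather than Megatron's real core attention.

# Hypothetical sketch of a CoreAttentionBuilder-compatible builder.
# SimpleCoreAttention / build_simple_core_attention are illustrative only.
from typing import Optional

import torch
from torch import Tensor

from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig


class SimpleCoreAttention(torch.nn.Module):
    """Toy module whose forward matches the CoreAttention protocol."""

    def __init__(self, config: TransformerConfig, layer_number: int) -> None:
        super().__init__()
        self.config = config
        self.layer_number = layer_number

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: Optional[Tensor],
        /,
        *,
        attn_mask_type: AttnMaskType,
        attention_bias: Optional[Tensor],
        packed_seq_params: Optional[PackedSeqParams],
    ) -> Tensor:
        # Placeholder: ignores the mask/bias/packed-seq arguments and does not
        # handle Megatron's [sq, b, np, hn] layout; a real implementation would.
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)


def build_simple_core_attention(
    *,
    config: TransformerConfig,
    layer_number: int,
    attn_mask_type: AttnMaskType,
    attention_type: str,
    cp_comm_type: Optional[str],
    softmax_scale: Optional[float],
    pg_collection: Optional[ProcessGroupCollection],
) -> SimpleCoreAttention:
    """Keyword-only callable with the signature required by CoreAttentionBuilder."""
    return SimpleCoreAttention(config=config, layer_number=layer_number)

With a builder like this, SelfAttentionSubmodules(linear_qkv=..., core_attention=build_simple_core_attention, ...) type-checks against the new annotation, and the __init__ change above can call submodules.core_attention(config=..., layer_number=..., ...) directly instead of routing through build_module.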