Commit 5c3618c

Add Protocol for core_attention
1 parent 497fbb4 commit 5c3618c

10 files changed (+73, -32 lines)


examples/multimodal/layer_specs.py

Lines changed: 2 additions & 2 deletions
@@ -113,7 +113,7 @@ def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
             params={"attn_mask_type": attn_mask_type},
             submodules=SelfAttentionSubmodules(
                 linear_qkv=not_none(TELayerNormColumnParallelLinear),
-                core_attention=TEDotProductAttention,
+                core_attention=not_none(TEDotProductAttention),
                 linear_proj=TERowParallelLinear,
                 q_layernorm=IdentityOp,
                 k_layernorm=IdentityOp,
@@ -159,7 +159,7 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
             params={"attn_mask_type": attn_mask_type},
             submodules=SelfAttentionSubmodules(
                 linear_qkv=not_none(TELayerNormColumnParallelLinear),
-                core_attention=TEDotProductAttention,
+                core_attention=not_none(TEDotProductAttention),
                 linear_proj=TERowParallelLinear,
             ),
         ),

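Note on the pattern above: wrapping the class in not_none(...) narrows the optionally-imported TEDotProductAttention (which can be None when Transformer Engine is absent) to a concrete type, so it satisfies the stricter core_attention field introduced in attention.py below. A hypothetical sketch of what such a narrowing helper typically looks like, for illustration only (the actual helper shipped in the repository may differ):

    from typing import Optional, TypeVar

    T = TypeVar("T")

    def not_none(value: Optional[T]) -> T:
        # Narrow Optional[T] to T for the type checker; fail loudly at
        # spec-build time if the optional dependency was never imported.
        assert value is not None
        return value
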
examples/multimodal/radio/radio_g.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ def get_radio_g_layer_spec_te() -> ModuleSpec:
             params={"attn_mask_type": attn_mask_type},
             submodules=SelfAttentionSubmodules(
                 linear_qkv=not_none(TELayerNormColumnParallelLinear),
-                core_attention=TEDotProductAttention,
+                core_attention=not_none(TEDotProductAttention),
                 linear_proj=TERowParallelLinear,
                 q_layernorm=IdentityOp,
                 k_layernorm=IdentityOp,

megatron/core/extensions/kitchen.py

Lines changed: 10 additions & 8 deletions
@@ -1431,9 +1431,9 @@ def forward(
         query: Tensor,
         key: Tensor,
         value: Tensor,
-        attention_mask: Tensor,
-        attn_mask_type: AttnMaskType = None,
-        attention_bias: Tensor = None,
+        attention_mask: Optional[Tensor],
+        attn_mask_type: Optional[AttnMaskType] = None,
+        attention_bias: Optional[Tensor] = None,
         packed_seq_params: Optional[PackedSeqParams] = None,
     ):
         """Forward."""
@@ -1581,11 +1581,11 @@ def forward(
         query: Tensor,
         key: Tensor,
         value: Tensor,
-        attention_mask: Tensor,
-        attn_mask_type: AttnMaskType = None,
-        attention_bias: Tensor = None,
+        attention_mask: Optional[Tensor],
+        attn_mask_type: Optional[AttnMaskType] = None,
+        attention_bias: Optional[Tensor] = None,
         packed_seq_params: Optional[PackedSeqParams] = None,
-    ):
+    ) -> Tensor:
         """Forward."""
         assert self.init_finished, "Must call finish_init before forward."
         assert packed_seq_params is None, (
@@ -1752,7 +1752,9 @@ def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> type:
         """Which module to use for layer norm"""
         return self.fallback.layer_norm(rms_norm=rms_norm, for_qk=for_qk)

-    def core_attention(self) -> type:
+    def core_attention(
+        self,
+    ) -> type[KitchenDotProductAttention] | type[KitchenFlashAttention] | type:
         """Which module to use for attention"""
         if not self.use_kitchen_attention:
             log_single_rank(

megatron/core/extensions/transformer_engine.py

Lines changed: 7 additions & 6 deletions
@@ -62,9 +62,10 @@
 )

 if TYPE_CHECKING:
-    # For type checking,
+    # For type checking, treat transformer_engine as always available.
     import transformer_engine as te
     from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_autocast
+
     HAVE_TE = True
 else:
     try:
@@ -1160,7 +1161,7 @@ def __init__(
         v_channels: Optional[int] = None,
         num_splits: Optional[int] = None,
         cp_comm_type: str = "p2p",
-        pg_collection: ProcessGroupCollection = None,
+        pg_collection: Optional[ProcessGroupCollection] = None,
     ):
         if not HAVE_TE:
             raise ImportError(
@@ -1334,12 +1335,12 @@ def forward(
         query: Tensor,
         key: Tensor,
         value: Tensor,
-        attention_mask: Tensor,
+        attention_mask: Optional[Tensor],
         attn_mask_type: AttnMaskType,
-        attention_bias: Tensor = None,
-        packed_seq_params: PackedSeqParams = None,
+        attention_bias: Optional[Tensor] = None,
+        packed_seq_params: Optional[PackedSeqParams] = None,
         num_splits: Optional[int] = None,
-    ):
+    ) -> torch.Tensor:
         """Forward."""
         # Default to constructor-provided num_splits unless explicitly overridden
         if num_splits is None:

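Context for the first hunk above: under TYPE_CHECKING the module imports transformer_engine unconditionally so static analysis always sees the real TE types, while the runtime path keeps the guarded import. A minimal sketch of the pattern, with the except-clause filled in as an assumption (the file's actual runtime branch may carry more detail):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # For type checking, treat transformer_engine as always available.
        import transformer_engine as te  # noqa: F401

        HAVE_TE = True
    else:
        try:
            import transformer_engine as te  # noqa: F401

            HAVE_TE = True
        except ImportError:
            # Transformer Engine is optional at runtime.
            HAVE_TE = False
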
megatron/core/models/T5/t5_spec.py

Lines changed: 2 additions & 2 deletions
@@ -66,7 +66,7 @@ def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
             params={"attn_mask_type": AttnMaskType.padding},
             submodules=SelfAttentionSubmodules(
                 linear_qkv=not_none(TELayerNormColumnParallelLinear),
-                core_attention=TEDotProductAttention,
+                core_attention=not_none(TEDotProductAttention),
                 linear_proj=TERowParallelLinear,
                 q_layernorm=IdentityOp,
                 k_layernorm=IdentityOp,
@@ -95,7 +95,7 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
             params={"attn_mask_type": AttnMaskType.causal},
             submodules=SelfAttentionSubmodules(
                 linear_qkv=not_none(TELayerNormColumnParallelLinear),
-                core_attention=TEDotProductAttention,
+                core_attention=not_none(TEDotProductAttention),
                 linear_proj=TERowParallelLinear,
                 q_layernorm=IdentityOp,
                 k_layernorm=IdentityOp,

megatron/core/models/backends.py

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> type:
             return FusedLayerNorm
         return TENorm

-    def core_attention(self) -> type:
+    def core_attention(self) -> type[TEDotProductAttention]:
         """Which module to use for attention"""
         return TEDotProductAttention

megatron/core/models/bert/bert_layer_specs.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def get_bert_layer_with_transformer_engine_spec():
             params={"attn_mask_type": AttnMaskType.padding},
             submodules=SelfAttentionSubmodules(
                 linear_qkv=not_none(TELayerNormColumnParallelLinear),
-                core_attention=TEDotProductAttention,
+                core_attention=not_none(TEDotProductAttention),
                 linear_proj=TERowParallelLinear,
                 q_layernorm=IdentityOp,
                 k_layernorm=IdentityOp,

megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def _get_heterogenous_attention_spec(
             linear_qkv=(
                 not_none(TELayerNormColumnParallelLinear) if use_te else ColumnParallelLinear
             ),
-            core_attention=TEDotProductAttention if use_te else DotProductAttention,
+            core_attention=not_none(TEDotProductAttention) if use_te else DotProductAttention,
             linear_proj=TERowParallelLinear if use_te else RowParallelLinear,
             q_layernorm=ln,
             k_layernorm=ln,

megatron/core/transformer/attention.py

Lines changed: 47 additions & 9 deletions
@@ -23,7 +23,9 @@
     get_tensor_model_parallel_world_size,
 )
 from megatron.core.process_groups_config import ProcessGroupCollection
-from megatron.core.tensor_parallel.mappings import all_gather_last_dim_from_tensor_parallel_region
+from megatron.core.tensor_parallel.mappings import (
+    all_gather_last_dim_from_tensor_parallel_region,
+)
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
@@ -168,14 +170,49 @@ def __call__(
     ) -> LinearLayer: ...


+class CoreAttention(Protocol):
+    """Protocol for core_attention modules."""
+
+    def forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        attention_mask: Optional[Tensor],
+        /,
+        *,
+        attn_mask_type: AttnMaskType,
+        attention_bias: Optional[Tensor],
+        packed_seq_params: Optional[PackedSeqParams],
+    ) -> Tensor:
+        """Applies dot product attention."""
+        ...
+
+
+class CoreAttentionBuilder(Protocol):
+    """Protocol for building core_attention layers."""
+
+    def __call__(
+        self,
+        *,
+        config: TransformerConfig,
+        layer_number: int,
+        attn_mask_type: AttnMaskType,
+        attention_type: str,
+        cp_comm_type: Optional[str],
+        softmax_scale: Optional[float],
+        pg_collection: Optional[ProcessGroupCollection],
+    ) -> CoreAttention: ...
+
+
 @dataclass
 class SelfAttentionSubmodules:
     """
     Configuration class for specifying the submodules of a self-attention.
     """

     linear_qkv: LinearQkvBuilder
-    core_attention: Union[ModuleSpec, type] = None
+    core_attention: CoreAttentionBuilder
     linear_proj: Union[ModuleSpec, type] = None
     q_layernorm: Union[ModuleSpec, type] = None
     k_layernorm: Union[ModuleSpec, type] = None
@@ -189,7 +226,7 @@ class CrossAttentionSubmodules:

     linear_q: LinearLayerBuilder
     linear_kv: LinearLayerBuilder
-    core_attention: Union[ModuleSpec, type] = None
+    core_attention: CoreAttentionBuilder
     linear_proj: Union[ModuleSpec, type] = None

@@ -273,8 +310,7 @@ def __init__(
             tmp_config.num_query_groups = world_size
         else:
             tmp_config = self.config
-        self.core_attention = build_module(
-            submodules.core_attention,
+        self.core_attention = submodules.core_attention(
             config=tmp_config,
             layer_number=self.layer_number,
             attn_mask_type=self.attn_mask_type,
@@ -342,7 +378,7 @@ def custom_forward(*inputs):
                 attention_mask = inputs[3]
                 attn_mask_type = inputs[5]
                 attn_mask_type = AttnMaskType(attn_mask_type.item())
-                output_ = self.core_attention(
+                output_ = apply_module(self.core_attention)(
                     query,
                     key,
                     value,
@@ -381,7 +417,9 @@ def _get_pp_layer_offset_for_inference(self):
         ), "Virtual pipeline parallelism is not supported for inference"

         # Import here to avoid circular imports
-        from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
+        from megatron.core.transformer.transformer_layer import (
+            get_transformer_layer_offset,
+        )

         return get_transformer_layer_offset(
             self.config, vp_stage=None, pp_rank=get_pg_rank(self.pg_collection.pp)
@@ -400,7 +438,7 @@ def _adjust_key_value_for_inference(
         sequence_len_offset: Optional[int] = None,
         *,
         inference_params: Optional[BaseInferenceContext] = None,
-    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, AttnMaskType, Tensor]:
         """
         Saves the generated key and value tensors to the end of the buffers in inference_context.
         Returns the full size keys and values from the provided inference_context, as well as
@@ -1017,7 +1055,7 @@ def forward(
             else:
                 if inference_context is None or inference_context.is_static_batching():
                     # Static batching attention kernel.
-                    core_attn_out = self.core_attention(
+                    core_attn_out = apply_module(self.core_attention)(
                         query,
                         key,
                         value,

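The new CoreAttention and CoreAttentionBuilder protocols are structural: any class whose constructor matches the builder's __call__ signature and whose forward matches CoreAttention satisfies them without inheriting from anything, which is why the spec files above can keep passing plain classes such as not_none(TEDotProductAttention). Below is a minimal, hypothetical stand-in that would type-check as a CoreAttentionBuilder; the class name and the toy attention math are illustrative only and ignore Megatron's tensor layout, masking, and parallelism details:

    from typing import Optional

    import torch
    from torch import Tensor


    class NaiveCoreAttention:
        """Toy module that structurally satisfies the CoreAttention protocol."""

        def __init__(
            self,
            *,
            config,                 # TransformerConfig
            layer_number: int,
            attn_mask_type,         # AttnMaskType
            attention_type: str,
            cp_comm_type: Optional[str] = None,
            softmax_scale: Optional[float] = None,
            pg_collection=None,     # ProcessGroupCollection
        ):
            self.softmax_scale = softmax_scale

        def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            attention_mask: Optional[Tensor],
            *,
            attn_mask_type=None,
            attention_bias: Optional[Tensor] = None,
            packed_seq_params=None,
        ) -> Tensor:
            # Plain scaled dot-product attention over the last two dims.
            scale = self.softmax_scale or query.shape[-1] ** -0.5
            scores = torch.matmul(query, key.transpose(-2, -1)) * scale
            return torch.matmul(torch.softmax(scores, dim=-1), value)

Because the builder protocol is keyword-only, Attention.__init__ can call submodules.core_attention(config=..., layer_number=..., ...) directly, and a class object like the one sketched here, or TEDotProductAttention itself, is accepted in its place.
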
megatron/core/transformer/dot_product_attention.py

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ def forward(
         query: Tensor,
         key: Tensor,
         value: Tensor,
-        attention_mask: Tensor,
+        attention_mask: Optional[Tensor],
         attn_mask_type: Optional[AttnMaskType] = None,
         attention_bias: Optional[Tensor] = None,
         packed_seq_params: Optional[PackedSeqParams] = None,
