
Commit 90e685b

Replaces ModuleSpec with Protocols for some of the inputs to SelfAttention/CrossAttention (#2761)
1 parent 096dbeb commit 90e685b

16 files changed (+318, -110 lines)

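The commit title describes the change: several submodule inputs to SelfAttention/CrossAttention are now typed against Protocols (structural interfaces) instead of being wrapped in ModuleSpec, so the spec files pass the layer classes directly and call sites instantiate them with a plain call. The hunks below show the call-site side of this; the sketch here only illustrates the idea, with made-up names rather than the actual Protocol definitions from megatron.core.

```python
# Illustrative sketch only: LinearLayerFactory and AttentionSubmodulesSketch
# are made-up names, not the Protocols defined in megatron.core.
from dataclasses import dataclass
from typing import Any, Protocol


class LinearLayerFactory(Protocol):
    """Anything callable like a linear-layer class (e.g. a ColumnParallelLinear type)."""

    def __call__(self, input_size: int, output_size: int, **kwargs: Any) -> Any: ...


@dataclass
class AttentionSubmodulesSketch:
    # Previously a ModuleSpec; now the class itself, checked structurally
    # against the Protocol, so construction is an ordinary call:
    #     self.linear_qkv = submodules.linear_qkv(hidden_size, qkv_size, ...)
    linear_qkv: LinearLayerFactory
    core_attention: type  # the real code narrows this further, e.g. type[TEDotProductAttention]
```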

examples/multimodal/layer_specs.py

Lines changed: 22 additions & 15 deletions
@@ -2,6 +2,10 @@
 import torch
 
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
+from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
+from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
+from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
+from megatron.core.ssm.mlp_layer import MLPLayer
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.dot_product_attention import DotProductAttention
@@ -10,10 +14,7 @@
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
-from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
-from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
-from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
-from megatron.core.ssm.mlp_layer import MLPLayer
+from megatron.core.typed_torch import not_none
 
 try:
     from megatron.core.extensions.transformer_engine import (
@@ -26,6 +27,13 @@
 
     HAVE_TE = True
 except ImportError:
+    (
+        TEColumnParallelLinear,
+        TEDotProductAttention,
+        TELayerNormColumnParallelLinear,
+        TENorm,
+        TERowParallelLinear,
+    ) = (None, None, None, None, None)
     HAVE_TE = False
 
 try:
@@ -54,12 +62,8 @@ def get_layer_spec(is_vit, normalization) -> ModuleSpec:
         norm = TENorm
     else:
         version = torch.__version__.split('.')
-        version_geq_2_4 = (
-            int(TORCH_VERSION[0]) > 2
-            or (
-                int(TORCH_VERSION[0]) == 2
-                and int(TORCH_VERSION[1]) >= 4
-            )
+        version_geq_2_4 = int(TORCH_VERSION[0]) > 2 or (
+            int(TORCH_VERSION[0]) == 2 and int(TORCH_VERSION[1]) >= 4
         )
         assert version_geq_2_4, "Torch version >= 2.4.0 is required for RMSNorm"
         if HAVE_APEX:
@@ -108,8 +112,8 @@ def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
                module=SelfAttention,
                params={"attn_mask_type": attn_mask_type},
                submodules=SelfAttentionSubmodules(
-                    linear_qkv=TELayerNormColumnParallelLinear,
-                    core_attention=TEDotProductAttention,
+                    linear_qkv=not_none(TELayerNormColumnParallelLinear),
+                    core_attention=not_none(TEDotProductAttention),
                    linear_proj=TERowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
@@ -122,6 +126,7 @@ def get_layer_spec_te(is_vit=False, padding=False) -> ModuleSpec:
        ),
    )
 
+
 def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
    attn_mask_type = AttnMaskType.causal
    # Padding mask is needed for e.g. Context Parallel.
@@ -153,8 +158,8 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
                module=SelfAttention,
                params={"attn_mask_type": attn_mask_type},
                submodules=SelfAttentionSubmodules(
-                    linear_qkv=TELayerNormColumnParallelLinear,
-                    core_attention=TEDotProductAttention,
+                    linear_qkv=not_none(TELayerNormColumnParallelLinear),
+                    core_attention=not_none(TEDotProductAttention),
                    linear_proj=TERowParallelLinear,
                ),
            ),
@@ -170,7 +175,8 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
            mlp=ModuleSpec(
                module=MLP,
                submodules=MLPSubmodules(
-                    linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
+                    linear_fc1=TELayerNormColumnParallelLinear,
+                    linear_fc2=TERowParallelLinear,
                ),
            ),
            mlp_bda=get_bias_dropout_add,
@@ -179,6 +185,7 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
        ),
    )
 
+
 def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
    # Dense MLP w/ or w/o TE modules.
    return ModuleSpec(
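The new `except ImportError` block binds the Transformer Engine classes to `None` when TE is absent, and the spec builders wrap them in `not_none` from `megatron.core.typed_torch`. The helper itself is not part of the hunks shown here; a typical implementation of such a narrowing helper looks like this (an assumption, not the verified Megatron source):

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def not_none(value: Optional[T]) -> T:
    """Narrow Optional[T] to T, failing fast if the value is actually None."""
    assert value is not None, "expected a non-None value"
    return value
```

With that shape, `not_none(TELayerNormColumnParallelLinear)` keeps the `SelfAttentionSubmodules` fields typed as concrete classes for the type checker and raises immediately if a TE-backed spec is built without Transformer Engine installed.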

examples/multimodal/nvlm/internvit.py

Lines changed: 13 additions & 8 deletions
@@ -14,7 +14,10 @@
 
 import torch
 
-from megatron.core.utils import divide
+from examples.multimodal.layer_scaling import (
+    LayerScalingTransformerLayer,
+    get_bias_dropout_add_layer_scaling,
+)
 from megatron.core.extensions.transformer_engine import (
     TEColumnParallelLinear,
     TEDotProductAttention,
@@ -35,9 +38,7 @@
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
-
-from examples.multimodal.layer_scaling import LayerScalingTransformerLayer, get_bias_dropout_add_layer_scaling
-
+from megatron.core.utils import divide
 
 try:
     import apex
@@ -128,10 +129,14 @@ def _gather_var(self, input_, max_dim):
 
        if rank < valid_ranks:  # Ranks without any dummy attention heads.
            var = input_.sum(-1, keepdim=True)
-        elif rank == valid_ranks:  # The only rank which may contain 'residual_heads' dummy attention heads.
+        elif (
+            rank == valid_ranks
+        ):  # The only rank which may contain 'residual_heads' dummy attention heads.
            var = input_[..., :max_dim].sum(-1, keepdim=True)
        else:
-            var = input_.sum(-1, keepdim=True) * 0.0  # All heads in these ranks are dummy heads: Zero-out.
+            var = (
+                input_.sum(-1, keepdim=True) * 0.0
+            )  # All heads in these ranks are dummy heads: Zero-out.
 
        tensor_list = [torch.empty_like(var) for _ in range(world_size)]
        tensor_list[rank] = var
@@ -175,8 +180,7 @@ def __init__(
        # Need to override linear_qkv, q_layernorm and k_layernorm.
        qkv_bias = False
 
-        self.linear_qkv = build_module(
-            submodules.linear_qkv,
+        self.linear_qkv = submodules.linear_qkv(
            self.config.hidden_size,
            self.query_projection_size + 2 * self.kv_projection_size,
            config=self.config,
@@ -256,6 +260,7 @@ def get_internvit_layer_spec(use_te) -> ModuleSpec:
        ),
    )
 
+
 def get_internvit300M_layer_spec(use_te) -> ModuleSpec:
    mlp = get_mlp_module_spec(use_te)  # no norm

examples/multimodal/radio/radio_g.py

Lines changed: 14 additions & 3 deletions
@@ -3,6 +3,10 @@
 
 import torch
 
+from examples.multimodal.layer_scaling import (
+    LayerScalingTransformerLayer,
+    get_bias_dropout_add_layer_scaling,
+)
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
 from megatron.core.transformer.dot_product_attention import DotProductAttention
@@ -11,7 +15,7 @@
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
-from examples.multimodal.layer_scaling import LayerScalingTransformerLayer, get_bias_dropout_add_layer_scaling
+from megatron.core.typed_torch import not_none
 
 try:
     from megatron.core.extensions.transformer_engine import (
@@ -24,6 +28,13 @@
 
     HAVE_TE = True
 except ImportError:
+    (
+        TEColumnParallelLinear,
+        TEDotProductAttention,
+        TELayerNormColumnParallelLinear,
+        TENorm,
+        TERowParallelLinear,
+    ) = (None, None, None, None, None)
     HAVE_TE = False
 
 try:
@@ -113,8 +124,8 @@ def get_radio_g_layer_spec_te() -> ModuleSpec:
                module=SelfAttention,
                params={"attn_mask_type": attn_mask_type},
                submodules=SelfAttentionSubmodules(
-                    linear_qkv=TELayerNormColumnParallelLinear,
-                    core_attention=TEDotProductAttention,
+                    linear_qkv=not_none(TELayerNormColumnParallelLinear),
+                    core_attention=not_none(TEDotProductAttention),
                    linear_proj=TERowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,

megatron/core/extensions/kitchen.py

Lines changed: 12 additions & 10 deletions
@@ -1431,9 +1431,9 @@ def forward(
        query: Tensor,
        key: Tensor,
        value: Tensor,
-        attention_mask: Tensor,
-        attn_mask_type: AttnMaskType = None,
-        attention_bias: Tensor = None,
+        attention_mask: Optional[Tensor],
+        attn_mask_type: Optional[AttnMaskType] = None,
+        attention_bias: Optional[Tensor] = None,
        packed_seq_params: Optional[PackedSeqParams] = None,
    ):
        """Forward."""
@@ -1581,11 +1581,11 @@ def forward(
        query: Tensor,
        key: Tensor,
        value: Tensor,
-        attention_mask: Tensor,
-        attn_mask_type: AttnMaskType = None,
-        attention_bias: Tensor = None,
+        attention_mask: Optional[Tensor],
+        attn_mask_type: Optional[AttnMaskType] = None,
+        attention_bias: Optional[Tensor] = None,
        packed_seq_params: Optional[PackedSeqParams] = None,
-    ):
+    ) -> Tensor:
        """Forward."""
        assert self.init_finished, "Must call finish_init before forward."
        assert packed_seq_params is None, (
@@ -1725,7 +1725,7 @@ def __init__(
        self.use_kitchen_attention = use_kitchen_attention
        self.kitchen_attention_backend = kitchen_attention_backend
 
-    def column_parallel_linear(self) -> type:
+    def column_parallel_linear(self) -> type[KitchenColumnParallelLinear]:
        """Which column parallel linear module kitchen backend uses"""
        return KitchenColumnParallelLinear
 
@@ -1744,15 +1744,17 @@ def fuse_layernorm_and_linear(self) -> bool:
        # explicitly about whether to include a norm.
        return self.fallback.fuse_layernorm_and_linear()
 
-    def column_parallel_layer_norm_linear(self) -> Optional[type]:
+    def column_parallel_layer_norm_linear(self) -> type[KitchenLayerNormColumnParallelLinear]:
        """Which module for sequential layernorm and linear"""
        return KitchenLayerNormColumnParallelLinear
 
    def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> type:
        """Which module to use for layer norm"""
        return self.fallback.layer_norm(rms_norm=rms_norm, for_qk=for_qk)
 
-    def core_attention(self) -> type:
+    def core_attention(
+        self,
+    ) -> type[KitchenDotProductAttention] | type[KitchenFlashAttention] | type:
        """Which module to use for attention"""
        if not self.use_kitchen_attention:
            log_single_rank(
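This file (and backends.py further down) also tightens the backend hook return annotations from bare `type` / `Optional[type]` to `type[SpecificClass]` or a union of such types. A small generic example of what that narrowing buys, using stand-in classes rather than the Kitchen modules:

```python
# Stand-in class; the point is the return annotation, not the implementation.
class DotProductAttentionStub:
    def __init__(self, num_heads: int) -> None:
        self.num_heads = num_heads


def core_attention() -> type[DotProductAttentionStub]:
    # type[...] tells the checker exactly which class is returned, so the
    # construction below is checked against the real __init__ signature.
    return DotProductAttentionStub


attn_cls = core_attention()
attn = attn_cls(num_heads=8)  # OK; attn_cls(heads="8") would be flagged by a checker
```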

megatron/core/extensions/transformer_engine.py

Lines changed: 17 additions & 10 deletions
@@ -8,7 +8,7 @@
 import pickle
 import warnings
 from contextlib import nullcontext
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -64,10 +64,17 @@
 
     HAVE_TE = True
 except ImportError:
-    from unittest.mock import MagicMock
+    if TYPE_CHECKING:
+        # For type checking, treat transformer_engine as always available.
+        import transformer_engine as te
+        from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_autocast
 
-    te = MagicMock()
-    HAVE_TE = False
+        HAVE_TE = True
+    else:
+        from unittest.mock import MagicMock
+
+        te = MagicMock()
+        HAVE_TE = False
 
 _TE_CONFIG_TYPE_KEY = "transformer_engine_config_type"
 
@@ -1152,8 +1159,8 @@ def __init__(
        k_channels: Optional[int] = None,
        v_channels: Optional[int] = None,
        num_splits: Optional[int] = None,
-        cp_comm_type: str = "p2p",
-        pg_collection: ProcessGroupCollection = None,
+        cp_comm_type: Optional[str] = "p2p",
+        pg_collection: Optional[ProcessGroupCollection] = None,
    ):
        if not HAVE_TE:
            raise ImportError(
@@ -1328,12 +1335,12 @@ def forward(
        query: Tensor,
        key: Tensor,
        value: Tensor,
-        attention_mask: Tensor,
+        attention_mask: Optional[Tensor],
        attn_mask_type: AttnMaskType,
-        attention_bias: Tensor = None,
-        packed_seq_params: PackedSeqParams = None,
+        attention_bias: Optional[Tensor] = None,
+        packed_seq_params: Optional[PackedSeqParams] = None,
        num_splits: Optional[int] = None,
-    ):
+    ) -> torch.Tensor:
        """Forward."""
        if packed_seq_params is not None:
            # If Dynamic CP group is provided, update TE DPA CP group
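The import hunk above keeps type checkers seeing the real transformer_engine module while the runtime fallback still mocks it when the package is missing. A generic sketch of that TYPE_CHECKING pattern, using a placeholder dependency name (`some_optional_dep` is hypothetical):

```python
from typing import TYPE_CHECKING

try:
    import some_optional_dep  # hypothetical optional dependency

    HAVE_DEP = True
except ImportError:
    if TYPE_CHECKING:
        # Static analysis evaluates this branch as if the import succeeded,
        # so annotations referring to the package keep their real types.
        import some_optional_dep

        HAVE_DEP = True
    else:
        # At runtime the missing package is replaced with a mock so the module
        # still imports; HAVE_DEP gates any actual use.
        from unittest.mock import MagicMock

        some_optional_dep = MagicMock()
        HAVE_DEP = False
```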

megatron/core/models/T5/t5_spec.py

Lines changed: 15 additions & 7 deletions
@@ -14,6 +14,7 @@
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
+from megatron.core.typed_torch import not_none
 
 try:
     import transformer_engine as te  # pylint: disable=unused-import
@@ -28,6 +29,13 @@
 
     HAVE_TE = True
 except ImportError:
+    (
+        TEColumnParallelLinear,
+        TEDotProductAttention,
+        TELayerNormColumnParallelLinear,
+        TENorm,
+        TERowParallelLinear,
+    ) = (None, None, None, None, None)
     HAVE_TE = False
 
 try:
@@ -57,8 +65,8 @@ def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.padding},
                submodules=SelfAttentionSubmodules(
-                    linear_qkv=TELayerNormColumnParallelLinear,
-                    core_attention=TEDotProductAttention,
+                    linear_qkv=not_none(TELayerNormColumnParallelLinear),
+                    core_attention=not_none(TEDotProductAttention),
                    linear_proj=TERowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
@@ -86,8 +94,8 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
-                    linear_qkv=TELayerNormColumnParallelLinear,
-                    core_attention=TEDotProductAttention,
+                    linear_qkv=not_none(TELayerNormColumnParallelLinear),
+                    core_attention=not_none(TEDotProductAttention),
                    linear_proj=TERowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
@@ -99,9 +107,9 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
                module=CrossAttention,
                params={"attn_mask_type": AttnMaskType.padding},
                submodules=CrossAttentionSubmodules(
-                    linear_q=TEColumnParallelLinear,
-                    linear_kv=TEColumnParallelLinear,
-                    core_attention=TEDotProductAttention,
+                    linear_q=not_none(TEColumnParallelLinear),
+                    linear_kv=not_none(TEColumnParallelLinear),
+                    core_attention=not_none(TEDotProductAttention),
                    linear_proj=TERowParallelLinear,
                ),
            ),

megatron/core/models/backends.py

Lines changed: 2 additions & 2 deletions
@@ -153,7 +153,7 @@ def fuse_layernorm_and_linear(self) -> bool:
        """TE backend chooses a single module for layernorm and linear"""
        return True
 
-    def column_parallel_layer_norm_linear(self) -> Optional[type]:
+    def column_parallel_layer_norm_linear(self) -> type[InferenceLayerNormColumnParallelLinear]:
        """Which module for sequential layernorm and linear"""
        return InferenceLayerNormColumnParallelLinear
 
@@ -166,7 +166,7 @@ def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> type:
            return FusedLayerNorm
        return TENorm
 
-    def core_attention(self) -> type:
+    def core_attention(self) -> type[TEDotProductAttention]:
        """Which module to use for attention"""
        return TEDotProductAttention
 