
Commit 8faf282

yuzhongw-nvidia authored and maanug-nv committed
[main] feat(moe): Support attention output gate for Qwen3-Next (3/4) (NVIDIA#2752)
1 parent 45ee0a4 commit 8faf282

File tree: 4 files changed, +126 -42 lines

megatron/core/transformer/attention.py (96 additions, 36 deletions)
@@ -9,6 +9,7 @@
 
 from megatron.core import tensor_parallel
 from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.jit import jit_fuser
 from megatron.core.models.common.embeddings.rope_utils import (
     apply_rotary_pos_emb,
     apply_rotary_pos_emb_with_cos_sin,
@@ -504,7 +505,9 @@ def _adjust_key_value_for_inference(
         return query, key, value, rotary_pos_emb, attn_mask_type, block_table
 
    @abstractmethod
-    def get_query_key_value_tensors(self, hidden_states, key_value_states, split_qkv=True):
+    def get_query_key_value_tensors(
+        self, hidden_states, key_value_states, output_gate=False, split_qkv=True
+    ):
         """
         This method needs to be implemented based on whether the derived class
         is "self-attn" or "cross-attn".
@@ -803,13 +806,24 @@ def forward(
         ), "fused_single_qkv_rope requested but not available/supported for the config."
 
         qkv_output = self.get_query_key_value_tensors(
-            hidden_states, key_value_states, split_qkv=split_qkv
+            hidden_states,
+            key_value_states,
+            split_qkv=split_qkv,
+            output_gate=self.config.attention_output_gate,
         )
         attn_mask_type = self.attn_mask_type
         block_table = None
+        gate = None
         if split_qkv:
-            query, key, value = qkv_output
+            if self.config.attention_output_gate:
+                query, key, value, gate = qkv_output
+            else:
+                query, key, value = qkv_output
+            mixed_qkv = qkv_split_arg_list = None
         else:
+            assert (
+                not self.config.attention_output_gate
+            ), "attention_output_gate is not supported for unsplit mixed_qkv tensor."
             mixed_qkv, qkv_split_arg_list = qkv_output
         nvtx_range_pop(suffix="qkv")
 
@@ -989,6 +1003,12 @@ def forward(
             core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
         nvtx_range_pop(suffix="core_attention")
 
+        # Output gate
+        if gate is not None:
+            nvtx_range_push(suffix="output_gate")
+            core_attn_out = self._apply_output_gate(core_attn_out, gate)
+            nvtx_range_pop(suffix="output_gate")
+
         # =================
         # Output. [sq, b, h]
         # =================
@@ -999,6 +1019,15 @@ def forward(
 
         return output, bias
 
+    @jit_fuser
+    def _apply_output_gate(self, x, gate):
+        x_dtype = x.dtype
+        gate = gate.contiguous()
+        gate = gate.view(*x.shape)
+        x = x * torch.sigmoid(gate.float())
+        x = x.to(x_dtype)
+        return x
+
     def set_for_recompute_input_layernorm(self):
         """Set the attention layer for recompute input_layernorm. Only needed for fp8."""
         raise NotImplementedError("set_for_recompute_input_layernorm is not implemented.")
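For context, a minimal standalone sketch of what _apply_output_gate computes (not the Megatron module itself): the gate logits are squashed with a sigmoid in fp32, multiplied elementwise into the attention output, and the result is cast back to the original dtype. The toy shapes below are illustrative only.

import torch

def apply_output_gate(attn_out: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    # Mirrors the logic above: fp32 sigmoid of the gate, elementwise product, cast back.
    out_dtype = attn_out.dtype
    gate = gate.contiguous().view(*attn_out.shape)
    gated = attn_out * torch.sigmoid(gate.float())
    return gated.to(out_dtype)

# Toy shapes [sq, b, h] = [3, 2, 8], picked only for the example.
attn_out = torch.randn(3, 2, 8, dtype=torch.bfloat16)
gate_logits = torch.randn(3, 2, 8, dtype=torch.bfloat16)
print(apply_output_gate(attn_out, gate_logits).shape)  # torch.Size([3, 2, 8])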
@@ -1037,10 +1066,13 @@ def __init__(
             pg_collection=pg_collection,
         )
 
+        self.linear_qkv_out_dim = self.query_projection_size + 2 * self.kv_projection_size
+        if self.config.attention_output_gate:
+            self.linear_qkv_out_dim += self.config.kv_channels * self.config.num_attention_heads
         self.linear_qkv = build_module(
             submodules.linear_qkv,
             self.config.hidden_size,
-            self.query_projection_size + 2 * self.kv_projection_size,
+            self.linear_qkv_out_dim,
             config=self.config,
             init_method=self.config.init_method,
             gather_output=False,
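As a quick sanity check on the widened projection, here is the arithmetic with made-up example sizes (hidden=128, 4 heads, 4 query groups, kv_channels=32; these numbers are illustrative, not taken from the commit):

num_attention_heads = 4
num_query_groups = 4            # no GQA in this toy example, so ng == np
kv_channels = 32                # per-head dimension (hn)
query_projection_size = kv_channels * num_attention_heads   # 128
kv_projection_size = kv_channels * num_query_groups         # 128

linear_qkv_out_dim = query_projection_size + 2 * kv_projection_size                 # 384 without the gate
linear_qkv_out_dim_gated = linear_qkv_out_dim + kv_channels * num_attention_heads   # 512 with the gate
print(linear_qkv_out_dim, linear_qkv_out_dim_gated)  # 384 512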
@@ -1142,13 +1174,23 @@ def _compare(srcs, tgts, names, parallelism):
            "TP",
        )
 
-    def get_query_key_value_tensors(self, hidden_states, key_value_states=None, split_qkv=True):
+    def get_query_key_value_tensors(
+        self, hidden_states, key_value_states=None, output_gate=False, split_qkv=True
+    ):
         """
-        Derives `query`, `key` and `value` tensors from `hidden_states`. If `split_qkv=False`, then
-        the unsplit mixed_qkv tensor is returned.
+        Derives `query`, `key` and `value` tensors from `hidden_states`.
+        If `output_gate` is True, then also derives `gate` tensor.
+        If `split_qkv=False`, then the unsplit mixed_qkv tensor is returned.
         """
-        # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
+        # If no output gate: Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
+        # If have output gate: Attention heads [sq, b, h] --> [sq, b, ng * (2 * np/ng + 2) * hn)]
         mixed_qkv, _ = self.linear_qkv(hidden_states)
+        num_query_heads_per_group = (
+            self.num_attention_heads_per_partition // self.num_query_groups_per_partition
+        )
+        num_qkv_heads_per_group = num_query_heads_per_group + 2
+        if output_gate:
+            num_qkv_heads_per_group += num_query_heads_per_group
 
         if self.config.num_query_groups < self.world_size:
             # Note that weights are interleaved in the following manner:
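Plugging toy numbers into the per-group head count (np=8 attention heads and ng=2 query groups per partition; illustrative only, not from the commit):

num_attention_heads_per_partition = 8
num_query_groups_per_partition = 2
num_query_heads_per_group = num_attention_heads_per_partition // num_query_groups_per_partition  # 4
num_qkv_heads_per_group = num_query_heads_per_group + 2                  # 6  -> (np/ng + 2)
num_qkv_heads_per_group_gated = num_qkv_heads_per_group + num_query_heads_per_group  # 10 -> (2 * np/ng + 2)
print(num_qkv_heads_per_group, num_qkv_heads_per_group_gated)            # 6 10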
@@ -1170,42 +1212,51 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None, spli
             size = mixed_qkv.size()[-1] // self.config.num_query_groups
             mixed_qkv = mixed_qkv[:, :, idx * size : (idx + 1) * size]
 
-        # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
+        # If no output gate: [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
+        # If have output gate: [sq, b, hp] --> [sq, b, ng, (2 * np/ng + 2) * hn]
         new_tensor_shape = mixed_qkv.size()[:-1] + (
             self.num_query_groups_per_partition,
-            (
-                (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
-                * self.hidden_size_per_attention_head
-            ),
+            num_qkv_heads_per_group * self.hidden_size_per_attention_head,
         )
         mixed_qkv = mixed_qkv.view(*new_tensor_shape)
 
-        split_arg_list = [
-            (
-                self.num_attention_heads_per_partition
-                // self.num_query_groups_per_partition
-                * self.hidden_size_per_attention_head
-            ),
-            self.hidden_size_per_attention_head,
-            self.hidden_size_per_attention_head,
-        ]
-
-        # Return unsplit mixed_qkv and split_arg_list
-        if not split_qkv:
-            return mixed_qkv, split_arg_list
-
-        if SplitAlongDim is not None:
+        # Split the tensor into query, gate, key, and value.
+        if output_gate:
+            if not split_qkv:
+                raise ValueError("split_qkv not supported for gated attention yet.")
+            # If have output gate: [sq, b, ng, (2 * np/ng + 2) * hn]
+            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, np/ng * hn],
+            #     [sq, b, ng, hn], [sq, b, ng, hn]
+            split_arg_list = [
+                num_query_heads_per_group * self.hidden_size_per_attention_head,
+                num_query_heads_per_group * self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head,
+            ]
 
-            # [sq, b, ng, (np/ng + 2) * hn]
-            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
+            if SplitAlongDim is not None:
+                (query, gate, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
+            else:
+                (query, gate, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)
         else:
+            # If no output gate: [sq, b, ng, (np/ng + 2) * hn]
+            # --> [sq, b, ng, np/ng * hn], None, [sq, b, ng, hn], [sq, b, ng, hn]
+            split_arg_list = [
+                num_query_heads_per_group * self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head,
+            ]
 
-            # [sq, b, ng, (np/ng + 2) * hn]
-            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)
+            # Return unsplit mixed_qkv and split_arg_list
+            if not split_qkv:
+                return mixed_qkv, split_arg_list
 
-        # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
+            if SplitAlongDim is not None:
+                (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
+            else:
+                (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)
+
+        # Query [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
         query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
 
         if self.config.num_query_groups < self.world_size:
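To make the new split concrete, here is an illustrative sketch with toy sizes (ng=2 query groups, np/ng=4 query heads per group, hn=8, sq=3, b=2; all numbers are made up for the example, and plain torch.split stands in for the optional SplitAlongDim path):

import torch

sq, b, ng, q_per_group, hn = 3, 2, 2, 4, 8
qkv_heads_per_group = 2 * q_per_group + 2          # query + gate heads, plus one key and one value head
mixed_qkv = torch.randn(sq, b, ng, qkv_heads_per_group * hn)

split_arg_list = [q_per_group * hn, q_per_group * hn, hn, hn]
query, gate, key, value = torch.split(mixed_qkv, split_arg_list, dim=3)

# Fold the group dimension back into per-head layout, as the code above does.
query = query.reshape(sq, b, -1, hn)   # [sq, b, np, hn] == [3, 2, 8, 8]
gate = gate.reshape(sq, b, -1, hn)     # [sq, b, np, hn] == [3, 2, 8, 8]
print(query.shape, gate.shape, key.shape, value.shape)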
@@ -1229,6 +1280,11 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None, spli
         if self.config.test_mode:
             self.run_realtime_tests()
 
+        if output_gate:
+            # Gate [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
+            gate = gate.reshape(*gate.shape[:2], -1, self.hidden_size_per_attention_head)
+            return query, key, value, gate
+
         return query, key, value
 
     def backward_dw(self) -> NoReturn:
@@ -1402,12 +1458,16 @@ def __init__(
             is_expert=False,
         )
 
-    def get_query_key_value_tensors(self, hidden_states, key_value_states, split_qkv=True):
+    def get_query_key_value_tensors(
+        self, hidden_states, key_value_states, output_gate=False, split_qkv=True
+    ):
         """
         Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
         from `key_value_states`.
         """
         assert split_qkv, "split_qkv must be True for CrossAttention"
+        assert not output_gate, "Output gate is not supported in cross attention for now."
+
         # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
         mixed_kv, _ = self.linear_kv(key_value_states)
 

megatron/core/transformer/transformer_config.py (10 additions, 0 deletions)
@@ -205,6 +205,9 @@ class TransformerConfig(ModelParallelConfig):
     """Whether to log the max attention logit across whole model. Decoupled from qk_clip,
     defualts to False. Setting qk_clip will automatically log the max logit"""
 
+    attention_output_gate: bool = False
+    """Whether to apply output gate to the attention layers."""
+
     test_mode: bool = False
     """Whether to run real-time tests."""
 
@@ -1355,6 +1358,10 @@ def __post_init__(self):
                    "apply_rope_fusion is not available. Please install TE >= 1.4."
                )
 
+        if self.fused_single_qkv_rope:
+            if self.attention_output_gate:
+                raise ValueError("fused_single_qkv_rope does not support gated attention for now.")
+
         if self.multi_latent_attention and self.rotary_interleaved:
             raise ValueError("rotary_interleaved does not work with multi_latent_attention.")
 
@@ -1716,6 +1723,9 @@ def __post_init__(self):
         if self.multi_latent_attention and self.apply_rope_fusion and self.rope_type != "yarn":
             raise ValueError("apply_rope_fusion for MLA only works with YARN RoPE.")
 
+        if self.attention_output_gate:
+            raise NotImplementedError("Output gate is not supported for MLA yet.")
+
         if self.cache_mla_latents:
             assert (
                 self.apply_rope_fusion is False
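A minimal standalone sketch of the two new guards (a hypothetical helper, not the TransformerConfig class itself), assuming only these three flags matter:

def check_output_gate_compat(attention_output_gate: bool,
                             fused_single_qkv_rope: bool,
                             multi_latent_attention: bool) -> None:
    # Mirrors the __post_init__ checks added above.
    if fused_single_qkv_rope and attention_output_gate:
        raise ValueError("fused_single_qkv_rope does not support gated attention for now.")
    if multi_latent_attention and attention_output_gate:
        raise NotImplementedError("Output gate is not supported for MLA yet.")

check_output_gate_compat(attention_output_gate=True,
                         fused_single_qkv_rope=False,
                         multi_latent_attention=False)   # passes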

megatron/training/arguments.py (3 additions, 1 deletion)
@@ -1674,6 +1674,8 @@ def _add_network_size_args(parser):
     group.add_argument('--group-query-attention', action='store_true',
                        help='Use group-query attention.')
     group.add_argument('--num-query-groups', type=int, default=1)
+    group.add_argument('--attention-output-gate', action='store_true',
+                       help='Whether to apply output gate to the attention.')
     group.add_argument('--softmax-type', type=str, default='vanilla',
                        choices=['learnable', 'vanilla', 'off-by-one'],
                        help='Type of softmax to use for the attention. Supports both a fixed offset and '
@@ -3138,7 +3140,7 @@ def _add_moe_args(parser):
                        '- A string containing a Python list expression that defines a custom pattern, e.g.: '
                        '"([1]*3+[0]*1)*3" evaluates to [1,1,1,0,1,1,1,0,1,1,1,0] '
                        'where 1 indicates an expert layer and 0 indicates a dense layer. '
-                       'Examples: "([0]+[1]*23)": 1 dense layer followed by 23 experts layers, '
+                       'Examples: "([0]+[1]*23)": 1 dense layer followed by 23 expert layers, '
                        '"([1]*3+[0]*2)*2": Three expert layers followed by two dense layers, repeated twice.')
     group.add_argument('--moe-ffn-hidden-size', type=int, default=None,
                        help='The hidden size of each expert\'s feed-forward network (ffn). '
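A hypothetical, stripped-down argparse sketch (not the full Megatron parser) showing how the new flag behaves: it is a store_true switch that defaults to False and maps onto TransformerConfig.attention_output_gate.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--attention-output-gate', action='store_true',
                    help='Whether to apply output gate to the attention.')

args = parser.parse_args(['--attention-output-gate'])
assert args.attention_output_gate is True   # omitting the flag would leave it False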

tests/unit_tests/transformer/test_attention.py (17 additions, 5 deletions)
@@ -25,9 +25,11 @@
     HAVE_FUSED_QKV_ROPE = False
 
 
+@pytest.mark.parametrize("output_gate", [False, True])
 class TestParallelAttention:
 
-    def setup_method(self, method):
+    @pytest.fixture(scope='function', autouse=True)
+    def setup_method(self, output_gate):
         Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
         self.transformer_config = TransformerConfig(
@@ -37,22 +39,26 @@ def setup_method(self, method):
             use_cpu_initialization=True,
             bf16=True,
             params_dtype=torch.bfloat16,
+            attention_output_gate=output_gate,
         )
         self.parallel_attention = SelfAttention(
             self.transformer_config,
             get_gpt_layer_with_transformer_engine_spec().submodules.self_attention.submodules,
             layer_number=1,
         )
 
-    def teardown_method(self, method):
+    def teardown_method(self):
         Utils.destroy_model_parallel()
 
     def test_constructor(self):
         assert isinstance(self.parallel_attention, SelfAttention)
         assert self.parallel_attention.layer_number == 1
 
         num_weights = sum([p.numel() for p in self.parallel_attention.parameters()])
-        assert num_weights == 66304
+        if self.transformer_config.attention_output_gate:
+            assert num_weights == 82816
+        else:
+            assert num_weights == 66304
 
     def test_cpu_forward(self):
         # we can't currently do this because the global memory buffer is on GPU
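Where the two parameter counts above come from, assuming the test config uses hidden_size=128 with no grouped-query attention and default kv_channels (so every per-projection width equals hidden_size); these sizes are inferred from the counts, not shown in the diff:

hidden = 128

def attn_params(gated: bool) -> int:
    qkv_out = 3 * hidden + (hidden if gated else 0)   # q + k + v (+ gate) projection width
    layernorm = 2 * hidden                            # LN weight + bias in the fused LayerNormLinear
    linear_qkv = qkv_out * hidden + qkv_out           # weight + bias
    linear_proj = hidden * hidden + hidden            # weight + bias
    return layernorm + linear_qkv + linear_proj

print(attn_params(False), attn_params(True))          # 66304 82816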
@@ -90,6 +96,8 @@ def test_fused_rope_gpu_forward(self, rotary_interleaved, fused_qkv_rope):
         self.parallel_attention.config.apply_rope_fusion = True
         if rotary_interleaved and not is_te_min_version("2.3.0"):
             pytest.skip("Only TE >= 2.3.0 supports interleaved fused RoPE.")
+        if fused_qkv_rope and self.parallel_attention.config.attention_output_gate:
+            pytest.skip("Fused QKV RoPE does not support gated attention for now.")
         if fused_qkv_rope and not HAVE_FUSED_QKV_ROPE:
             pytest.skip("Fused QKV RoPE not available.")
         self.parallel_attention.config.rotary_interleaved = rotary_interleaved
@@ -343,12 +351,15 @@ def test_clip_qk_mixed_logits(self):
         assert attention.core_attention.current_max_attn_logits is None
 
 
+@pytest.mark.parametrize("output_gate", [False, True])
 class TestSelfAttention:
 
-    def setup_method(self, method):
+    @pytest.fixture(scope='function', autouse=True)
+    def setup_method(self, output_gate):
+        self.output_gate = output_gate
         Utils.destroy_model_parallel()
 
-    def teardown_method(self, method):
+    def teardown_method(self):
         Utils.destroy_model_parallel()
 
     def run_self_attention(self, pg_collection):
@@ -357,6 +368,7 @@ def run_self_attention(self, pg_collection):
             num_layers=2,
             hidden_size=128,
             num_attention_heads=4,
+            attention_output_gate=self.output_gate,
             tensor_model_parallel_size=tensor_model_parallel_size,
             use_cpu_initialization=False,
         )
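For reference, a minimal standalone sketch (hypothetical class and fixture names) of the test pattern used above: a class-level pytest.mark.parametrize whose parameter is consumed by an autouse fixture, so every test in the class runs once with output_gate=False and once with output_gate=True.

import pytest

@pytest.mark.parametrize("output_gate", [False, True])
class TestOutputGateSketch:
    @pytest.fixture(scope='function', autouse=True)
    def _setup(self, output_gate):
        # The parametrized value is delivered to the fixture by argument name.
        self.output_gate = output_gate

    def test_flag_is_bool(self):
        assert isinstance(self.output_gate, bool)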
