 
 from megatron.core import tensor_parallel
 from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.jit import jit_fuser
 from megatron.core.models.common.embeddings.rope_utils import (
     apply_rotary_pos_emb,
     apply_rotary_pos_emb_with_cos_sin,
@@ -504,7 +505,9 @@ def _adjust_key_value_for_inference(
         return query, key, value, rotary_pos_emb, attn_mask_type, block_table
 
     @abstractmethod
-    def get_query_key_value_tensors(self, hidden_states, key_value_states, split_qkv=True):
+    def get_query_key_value_tensors(
+        self, hidden_states, key_value_states, output_gate=False, split_qkv=True
+    ):
         """
         This method needs to be implemented based on whether the derived class
         is "self-attn" or "cross-attn".
@@ -803,13 +806,24 @@ def forward(
803806 ), "fused_single_qkv_rope requested but not available/supported for the config."
804807
805808 qkv_output = self .get_query_key_value_tensors (
806- hidden_states , key_value_states , split_qkv = split_qkv
809+ hidden_states ,
810+ key_value_states ,
811+ split_qkv = split_qkv ,
812+ output_gate = self .config .attention_output_gate ,
807813 )
808814 attn_mask_type = self .attn_mask_type
809815 block_table = None
816+ gate = None
810817 if split_qkv :
811- query , key , value = qkv_output
818+ if self .config .attention_output_gate :
819+ query , key , value , gate = qkv_output
820+ else :
821+ query , key , value = qkv_output
822+ mixed_qkv = qkv_split_arg_list = None
812823 else :
824+ assert (
825+ not self .config .attention_output_gate
826+ ), "attention_output_gate is not supported for unsplit mixed_qkv tensor."
813827 mixed_qkv , qkv_split_arg_list = qkv_output
814828 nvtx_range_pop (suffix = "qkv" )
815829
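For reference, the unpacking above relies on the return contract of `get_query_key_value_tensors`: a 4-tuple when `attention_output_gate` is enabled, a 3-tuple otherwise, and `(mixed_qkv, split_arg_list)` when `split_qkv=False` (where gating is rejected). A minimal standalone sketch of that contract, using a hypothetical helper and made-up tensors rather than the real module:

```python
import torch

def unpack_qkv_output(qkv_output, attention_output_gate, split_qkv):
    """Hypothetical helper mirroring the branching above (not a Megatron API)."""
    gate = mixed_qkv = qkv_split_arg_list = None
    if split_qkv:
        if attention_output_gate:
            query, key, value, gate = qkv_output      # 4-tuple when gating is on
        else:
            query, key, value = qkv_output            # 3-tuple otherwise
    else:
        assert not attention_output_gate, "gating requires split_qkv=True"
        mixed_qkv, qkv_split_arg_list = qkv_output    # unsplit path
        query = key = value = None
    return query, key, value, gate, mixed_qkv, qkv_split_arg_list

# Example: split path with gating enabled (dummy tensors, shapes are illustrative).
q, k, v, g = (torch.randn(4, 2, 8, 16) for _ in range(4))
out = unpack_qkv_output((q, k, v, g), attention_output_gate=True, split_qkv=True)
assert out[3] is g
```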
@@ -989,6 +1003,12 @@ def forward(
             core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
         nvtx_range_pop(suffix="core_attention")
 
+        # Output gate
+        if gate is not None:
+            nvtx_range_push(suffix="output_gate")
+            core_attn_out = self._apply_output_gate(core_attn_out, gate)
+            nvtx_range_pop(suffix="output_gate")
+
         # =================
         # Output. [sq, b, h]
         # =================
@@ -999,6 +1019,15 @@ def forward(
 
         return output, bias
 
+    @jit_fuser
+    def _apply_output_gate(self, x, gate):
+        x_dtype = x.dtype
+        gate = gate.contiguous()
+        gate = gate.view(*x.shape)
+        x = x * torch.sigmoid(gate.float())
+        x = x.to(x_dtype)
+        return x
+
     def set_for_recompute_input_layernorm(self):
         """Set the attention layer for recompute input_layernorm. Only needed for fp8."""
         raise NotImplementedError("set_for_recompute_input_layernorm is not implemented.")
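The numerics of the new `_apply_output_gate` are plain elementwise gating: the core attention output is multiplied by `sigmoid(gate)` computed in fp32, then cast back to the activation dtype. A standalone reference sketch of those numerics (without `@jit_fuser`; the shapes are illustrative, not taken from the patch):

```python
import torch

def apply_output_gate_reference(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    """Reference numerics: x * sigmoid(gate), sigmoid taken in fp32, result cast back."""
    x_dtype = x.dtype
    gate = gate.contiguous().view(*x.shape)   # flatten gate (e.g. [sq, b, np, hn]) to match x
    out = x * torch.sigmoid(gate.float())     # compute the gate in fp32 for stability
    return out.to(x_dtype)

# Sanity check in bf16: a zero gate should halve the attention output (sigmoid(0) = 0.5).
x = torch.randn(4, 2, 128, dtype=torch.bfloat16)        # e.g. [sq, b, np * hn]
gate = torch.zeros(4, 2, 8, 16, dtype=torch.bfloat16)   # e.g. [sq, b, np, hn]
y = apply_output_gate_reference(x, gate)
assert torch.allclose(y.float(), 0.5 * x.float(), atol=1e-2)
```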
@@ -1037,10 +1066,13 @@ def __init__(
             pg_collection=pg_collection,
         )
 
+        self.linear_qkv_out_dim = self.query_projection_size + 2 * self.kv_projection_size
+        if self.config.attention_output_gate:
+            self.linear_qkv_out_dim += self.config.kv_channels * self.config.num_attention_heads
         self.linear_qkv = build_module(
             submodules.linear_qkv,
             self.config.hidden_size,
-            self.query_projection_size + 2 * self.kv_projection_size,
+            self.linear_qkv_out_dim,
             config=self.config,
             init_method=self.config.init_method,
             gather_output=False,
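The widened fused projection adds one head-sized gate slice per query head. With the usual Megatron definitions `query_projection_size = kv_channels * num_attention_heads` and `kv_projection_size = kv_channels * num_query_groups`, the arithmetic works out as in this sketch (the config numbers are hypothetical, not from the patch):

```python
# Hypothetical GQA config for illustration: 32 query heads, 8 KV groups, 128 channels per head.
num_attention_heads = 32
num_query_groups = 8
kv_channels = 128

query_projection_size = kv_channels * num_attention_heads     # 4096
kv_projection_size = kv_channels * num_query_groups           # 1024

# Without output gate: fused QKV width = Q + K + V.
qkv_out_dim = query_projection_size + 2 * kv_projection_size  # 6144
# With attention_output_gate: one extra head-sized gate slice per query head.
qkv_out_dim_gated = qkv_out_dim + kv_channels * num_attention_heads  # 10240

# The gated width is exactly one extra query-sized projection.
assert qkv_out_dim_gated == 2 * query_projection_size + 2 * kv_projection_size
```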
@@ -1142,13 +1174,23 @@ def _compare(srcs, tgts, names, parallelism):
11421174 "TP" ,
11431175 )
11441176
1145- def get_query_key_value_tensors (self , hidden_states , key_value_states = None , split_qkv = True ):
1177+ def get_query_key_value_tensors (
1178+ self , hidden_states , key_value_states = None , output_gate = False , split_qkv = True
1179+ ):
11461180 """
1147- Derives `query`, `key` and `value` tensors from `hidden_states`. If `split_qkv=False`, then
1148- the unsplit mixed_qkv tensor is returned.
1181+ Derives `query`, `key` and `value` tensors from `hidden_states`.
1182+ If `output_gate` is True, then also derives `gate` tensor.
1183+ If `split_qkv=False`, then the unsplit mixed_qkv tensor is returned.
11491184 """
1150- # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
1185+ # If no output gate: Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
1186+ # If have output gate: Attention heads [sq, b, h] --> [sq, b, ng * (2 * np/ng + 2) * hn)]
11511187 mixed_qkv , _ = self .linear_qkv (hidden_states )
1188+ num_query_heads_per_group = (
1189+ self .num_attention_heads_per_partition // self .num_query_groups_per_partition
1190+ )
1191+ num_qkv_heads_per_group = num_query_heads_per_group + 2
1192+ if output_gate :
1193+ num_qkv_heads_per_group += num_query_heads_per_group
11521194
11531195 if self .config .num_query_groups < self .world_size :
11541196 # Note that weights are interleaved in the following manner:
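The head-count bookkeeping above packs, per query group, `np/ng` query heads plus one key and one value head, and with gating an additional `np/ng` gate heads. A quick arithmetic sketch with hypothetical per-partition numbers:

```python
# Illustrative per-partition head counts (hypothetical: TP=2 over the config sketched earlier).
num_attention_heads_per_partition = 16   # np
num_query_groups_per_partition = 4       # ng

num_query_heads_per_group = (
    num_attention_heads_per_partition // num_query_groups_per_partition
)                                                                       # np/ng = 4

# Each group packs its query heads plus one K and one V head ...
num_qkv_heads_per_group = num_query_heads_per_group + 2                 # np/ng + 2 = 6
# ... and, with attention_output_gate, one gate head per query head as well.
num_qkv_heads_per_group += num_query_heads_per_group                    # 2 * np/ng + 2 = 10
```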
@@ -1170,42 +1212,51 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None, spli
             size = mixed_qkv.size()[-1] // self.config.num_query_groups
             mixed_qkv = mixed_qkv[:, :, idx * size : (idx + 1) * size]
 
-        # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
+        # If no output gate: [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
+        # If using output gate: [sq, b, hp] --> [sq, b, ng, (2 * np/ng + 2) * hn]
         new_tensor_shape = mixed_qkv.size()[:-1] + (
             self.num_query_groups_per_partition,
-            (
-                (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
-                * self.hidden_size_per_attention_head
-            ),
+            num_qkv_heads_per_group * self.hidden_size_per_attention_head,
         )
         mixed_qkv = mixed_qkv.view(*new_tensor_shape)
 
-        split_arg_list = [
-            (
-                self.num_attention_heads_per_partition
-                // self.num_query_groups_per_partition
-                * self.hidden_size_per_attention_head
-            ),
-            self.hidden_size_per_attention_head,
-            self.hidden_size_per_attention_head,
-        ]
-
-        # Return unsplit mixed_qkv and split_arg_list
-        if not split_qkv:
-            return mixed_qkv, split_arg_list
-
-        if SplitAlongDim is not None:
+        # Split the tensor into query, gate, key, and value.
+        if output_gate:
+            if not split_qkv:
+                raise ValueError("split_qkv=False is not supported for gated attention yet.")
+            # If using output gate: [sq, b, ng, (2 * np/ng + 2) * hn]
+            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, np/ng * hn],
+            #     [sq, b, ng, hn], [sq, b, ng, hn]
+            split_arg_list = [
+                num_query_heads_per_group * self.hidden_size_per_attention_head,
+                num_query_heads_per_group * self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head,
+            ]
 
-            # [sq, b, ng, (np/ng + 2) * hn]
-            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
+            if SplitAlongDim is not None:
+                (query, gate, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
+            else:
+                (query, gate, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)
         else:
+            # If no output gate: [sq, b, ng, (np/ng + 2) * hn]
+            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
+            split_arg_list = [
+                num_query_heads_per_group * self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head,
+                self.hidden_size_per_attention_head,
+            ]
 
-            # [sq, b, ng, (np/ng + 2) * hn]
-            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)
+            # Return unsplit mixed_qkv and split_arg_list
+            if not split_qkv:
+                return mixed_qkv, split_arg_list
 
-        # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
+            if SplitAlongDim is not None:
+                (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
+            else:
+                (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)
+
+        # Query [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
         query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
 
         if self.config.num_query_groups < self.world_size:
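Putting the layout together: the gated branch views `mixed_qkv` as `[sq, b, ng, (2 * np/ng + 2) * hn]`, splits the last dimension into query, gate, key, and value, then flattens query and gate back to per-head layout. A self-contained sketch of that split with made-up sizes, using `torch.split` (the fallback the code takes when `SplitAlongDim` is unavailable):

```python
import torch

# Made-up sizes: sq=4, b=2, ng=4 groups, np/ng=4 query heads per group, hn=8 channels per head.
sq, b, ng, q_per_g, hn = 4, 2, 4, 4, 8
np_heads = ng * q_per_g

# Gated per-group layout: [ q (np/ng * hn) | gate (np/ng * hn) | k (hn) | v (hn) ]
mixed_qkv = torch.randn(sq, b, ng, (2 * q_per_g + 2) * hn)
split_arg_list = [q_per_g * hn, q_per_g * hn, hn, hn]
query, gate, key, value = torch.split(mixed_qkv, split_arg_list, dim=3)

# Query and gate collapse the group dim back to per-head layout: [sq, b, np, hn].
query = query.reshape(sq, b, -1, hn)
gate = gate.reshape(*gate.shape[:2], -1, hn)
assert query.shape == gate.shape == (sq, b, np_heads, hn)
assert key.shape == value.shape == (sq, b, ng, hn)
```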
@@ -1229,6 +1280,11 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None, spli
         if self.config.test_mode:
             self.run_realtime_tests()
 
+        if output_gate:
+            # Gate [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
+            gate = gate.reshape(*gate.shape[:2], -1, self.hidden_size_per_attention_head)
+            return query, key, value, gate
+
         return query, key, value
 
     def backward_dw(self) -> NoReturn:
@@ -1402,12 +1458,16 @@ def __init__(
             is_expert=False,
         )
 
-    def get_query_key_value_tensors(self, hidden_states, key_value_states, split_qkv=True):
+    def get_query_key_value_tensors(
+        self, hidden_states, key_value_states, output_gate=False, split_qkv=True
+    ):
         """
         Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
         from `key_value_states`.
         """
         assert split_qkv, "split_qkv must be True for CrossAttention"
+        assert not output_gate, "Output gate is not supported in cross attention for now."
+
         # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
         mixed_kv, _ = self.linear_kv(key_value_states)
 