7 changes: 7 additions & 0 deletions megatron/arguments.py
@@ -401,6 +401,10 @@ def _add_network_size_args(parser):
'attention. This is set to '
' args.hidden_size // args.num_attention_heads '
'if not provided.')
group.add_argument('--attention-head-type', type=str, default='multihead',
choices=['multihead', 'multiquery'],
help='Type of attention heads. `multihead` is the standard multi-head attention. '
'`multiquery` shares the keys and values across attention heads.')
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
@@ -477,6 +481,9 @@ def _add_logging_args(parser):
help="Name of wandb entity for reporting")
group.add_argument('--wandb-project-name', type=str, default=None,
help="Name of wandb project")
group.add_argument('--transformer-timers', action='store_true',
help="If set, activate the timers within the transformer layers. "
"Only for debugging, as this slows down the model.")

return parser

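For orientation, here is a minimal sketch of what the two `--attention-head-type` choices compute (plain PyTorch with hypothetical names; single device, no masking, softmax, or dropout, unlike the actual implementation below): standard multi-head attention gives every head its own keys and values, while multi-query attention shares one key/value head across all query heads.

# Minimal sketch (assumptions: single device, no masking/softmax/dropout).
# Shapes follow the [s, b, np, hn] convention used in transformer.py below.
import torch

def multihead_scores(q, k):
    # q, k: [s, b, np, hn] -- every head has its own keys.
    return torch.einsum('sbnh,tbnh->bnst', q, k)

def multiquery_scores(q, k):
    # q: [s, b, np, hn], k: [s, b, hn] -- one key head shared by all np heads.
    return torch.einsum('sbnh,tbh->bnst', q, k)

s, b, num_heads, hn = 5, 2, 4, 8
q = torch.randn(s, b, num_heads, hn)
assert multihead_scores(q, torch.randn(s, b, num_heads, hn)).shape == (b, num_heads, s, s)
assert multiquery_scores(q, torch.randn(s, b, hn)).shape == (b, num_heads, s, s)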
199 changes: 189 additions & 10 deletions megatron/model/transformer.py
@@ -27,7 +27,7 @@
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_linear_layer


from .glu_activations import GLU_ACTIVATIONS
@@ -233,6 +233,7 @@ def forward(self, query_layer, key_layer,
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
np = query_layer.size(2)

# [b, np, sq, sk]
output_size = (query_layer.size(1),
@@ -253,6 +254,7 @@
(output_size[0]*output_size[1], output_size[2], output_size[3]),
query_layer.dtype, "mpu")
else:
# alibi: (batch_size * num_attention_heads, 1, max_seq_len)
matmul_input_buffer = alibi[:output_size[0]*output_size[1], :, :output_size[3]]

# Raw attention scores. [b * np, sq, sk]
@@ -307,7 +309,7 @@ def forward(self, query_layer, key_layer,

# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
np,
query_layer.size(0),
value_layer.size(3))

@@ -336,6 +338,127 @@ def forward(self, query_layer, key_layer,
return context_layer


class MultiQueryCoreAttention(CoreAttention):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi):
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
sq = query_layer.size(0)
bs = query_layer.size(1)
np = query_layer.size(2)

sk = key_layer.size(0)
# Only one head for key and values
assert key_layer.size(2) == 1 and value_layer.size(2) == 1

# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))

# [sq, b, np, hn] -> [b, np * sq, hn]
query_layer = query_layer.permute([1, 2, 0, 3]).reshape(bs, np * sq, -1)
# [sk, b, 1, hn] -> [b, hn, sk]
key_layer = key_layer.squeeze(2).permute(1, 2, 0)
# [sk, b, 1, hn] -> [sk, b * np, hn]
# key_layer = key_layer.expand(output_size[3], output_size[0], np, -1)
# key_layer = key_layer.reshape(output_size[3], output_size[0] * np, -1)

if alibi is None:
# preallocting input tensor: [b, np * sq, sk]
matmul_input_buffer = get_global_memory_buffer().get_tensor(
(bs, np * sq, sk),
query_layer.dtype, "mpu")
else:
# alibi: (batch_size * num_attention_heads, 1, max_seq_len)
# TODO: ideally, alibi would have the shape: (1, num_heads * sq, sk)
# Query rows are ordered head-major (np outer, sq inner), so each head's
# bias row must be repeated sq times consecutively.
matmul_input_buffer = alibi[:bs * np, :, :sk].view(bs, np, sk)
matmul_input_buffer = matmul_input_buffer.repeat_interleave(sq, dim=1)  # [b, np * sq, sk]

if alibi is None:
# Raw attention scores. [b, np * sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer, # [b, np * sq, hn]
key_layer, # [b, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
else:
if not hasattr(self, "logged_alibi"):
print("Using Alibi.")
self.logged_alibi = True

if self.apply_query_key_layer_scaling:
beta = 1.0 / self.layer_number
else:
beta = 1.0

matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer,
key_layer,
beta=beta, alpha=(1.0 / self.norm_factor))

# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(bs, np, sq, sk)

# ===========================
# Attention probs and dropout
# ===========================

# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.

if not self.sequence_parallel:
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)

# =========================
# Context layer. [sq, b, hp]
# =========================

# value_layer -> context layer.
# [sk, b, 1, hn] --> [b, np, sq, hn]

# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
np,
sq,
value_layer.size(3))

# [sk, b, 1, hn] -> [b, sk, hn]
value_layer = value_layer.squeeze(2).transpose(0, 1)

# change view [b, np * sq, sk]
attention_probs = attention_probs.view(bs, np * sq, -1)

# matmul: [b, np * sq, hn]
context_layer = torch.bmm(attention_probs, value_layer)

# change view [b, np, sq, hn]
context_layer = context_layer.view(bs, np, sq, -1)

# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)

return context_layer
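As a sanity check on the reshapes above, the sketch below (standalone, hypothetical toy shapes, no softmax, masking, or dropout) verifies that folding the head dimension into the batched-matmul row dimension -- query as [b, np * sq, hn] against the single shared key head as [b, hn, sk] -- produces the same raw scores as expanding the key to every head in the usual multi-head layout.

# Shape sketch (hypothetical standalone check, not the Megatron code path).
import torch

sq, sk, b, np_, hn = 4, 6, 2, 3, 8
query = torch.randn(sq, b, np_, hn)     # [sq, b, np, hn]
key = torch.randn(sk, b, 1, hn)         # [sk, b, 1, hn] -- single shared head

# Folded form used by MultiQueryCoreAttention.
q_folded = query.permute(1, 2, 0, 3).reshape(b, np_ * sq, hn)   # [b, np*sq, hn]
k_folded = key.squeeze(2).permute(1, 2, 0)                      # [b, hn, sk]
scores_folded = torch.bmm(q_folded, k_folded).view(b, np_, sq, sk)

# Reference: expand the shared key to all heads and use the multi-head layout.
k_expanded = key.expand(sk, b, np_, hn)
scores_ref = torch.einsum('sbnh,tbnh->bnst', query, k_expanded)

assert torch.allclose(scores_folded, scores_ref, atol=1e-5)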


class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.

@@ -353,6 +476,7 @@ def __init__(self, init_method,
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
self.attention_head_type = args.attention_head_type

projection_size = args.kv_channels * args.num_attention_heads

@@ -364,13 +488,28 @@ def __init__(self, init_method,
args.num_attention_heads, world_size)

# Strided linear layer.
if attention_type == AttnType.self_attn:
if attention_type == AttnType.self_attn and self.attention_head_type == 'multihead':
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
else:
elif attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
# TODO: Find a way to merge the query and key-value computations?
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
# In MultiQuery attention, keys and values are shared across heads
# Use args.kv_channels instead of projection_size
# No `.fork()` so the rng tracker is shared across tensor-parallel processes.
# with mpu.get_cuda_rng_tracker():
self.key_value = get_linear_layer(
args.hidden_size,
2 * args.kv_channels,
init_method=init_method)
elif attention_type == AttnType.cross_attn and self.attention_head_type == 'multihead':
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
@@ -383,9 +522,14 @@ def __init__(self, init_method,
2 * projection_size,
gather_output=False,
init_method=init_method)
else:
raise NotImplementedError("Multiquery attention not implemented for cross-attention.")

self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type)
if self.attention_head_type == 'multihead':
self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type)
else:
self.core_attention = MultiQueryCoreAttention(self.layer_number, self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == 'selective'

# Output.
@@ -419,15 +563,15 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.num_attention_heads_per_partition if self.attention_head_type == "multihead" else 1,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())


def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None, alibi=None):
# hidden_states: [sq, b, h]

# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
@@ -449,7 +593,7 @@ def forward(self, hidden_states, attention_mask,
# Query, Key, and Value
# =====================

if self.attention_type == AttnType.self_attn:
if self.attention_type == AttnType.self_attn and self.attention_head_type == 'multihead':
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

@@ -463,6 +607,35 @@ def forward(self, hidden_states, attention_mask,
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
elif self.attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
# Attention heads [sq, b, h] --> [sq, b, (2 * hn)]
mixed_kv_layer = self.key_value(hidden_states)

# [sq, b, (2 * hn)] --> [sq, b, np (expanded), 2 * hn]
# new_tensor_shape = mixed_kv_layer.size()[:-1] + \
# (self.num_attention_heads_per_partition,
# 2 * self.hidden_size_per_attention_head)
# mixed_kv_layer = mixed_kv_layer.unsqueeze(2).expand(*new_tensor_shape)

# [sq, b, (2 * hn)] --> [sq, b, 1, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(1,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

# [sq, b, 1, 2 * hn] --> 2 [sq, b, 1, hn]
(key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

# Attention head [sq, b, h] --> [sq, b, np * hn]
query_layer, _ = self.query(hidden_states)
# [sq, b, np * hn] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)

# [sq, b, np, hn] -> [b, np * sq, hn]
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
@@ -489,6 +662,7 @@ def forward(self, hidden_states, attention_mask,
# Adjust key and value for inference
# ==================================


if inference_params:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
@@ -520,7 +694,6 @@ def forward(self, hidden_states, attention_mask,
# =================
# Output. [sq, b, h]
# =================

output, bias = self.dense(context_layer)

return output, bias
@@ -963,6 +1136,10 @@ def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
timers = get_timers()
args = get_args()

if args.transformer_timers: timers("Transformer forward").start()

# Checks.
if inference_params:
@@ -1020,4 +1197,6 @@ def forward(self, hidden_states, attention_mask,
if self.post_process and self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)

if args.transformer_timers: timers("Transformer forward").stop()

return hidden_states
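On the `--transformer-timers` warning that the timers slow the model down: accurate GPU timing requires synchronizing the host with the device, which stalls the launch pipeline. A hedged sketch of the usual CUDA-event pattern (not Megatron's `Timers` class; `timed_forward` is a hypothetical helper):

# Hedged sketch of per-layer timing with CUDA events; the explicit
# synchronization needed for accurate numbers is the overhead that makes
# per-layer timers slow the model.
import torch

def timed_forward(module, *inputs):
    if not torch.cuda.is_available():
        return module(*inputs), None
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    output = module(*inputs)
    end.record()
    torch.cuda.synchronize()                 # blocks the host until the GPU is done
    return output, start.elapsed_time(end)   # elapsed time in milliseconds

# Example usage (hypothetical module and inputs):
# out, ms = timed_forward(transformer_layer, hidden_states, attention_mask)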
3 changes: 2 additions & 1 deletion megatron/mpu/layers.py
@@ -264,7 +264,8 @@ def backward(ctx, grad_output):
handle.wait()

# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
# TODO: Is the reshape preventing us from getting a speedup here?
grad_output = grad_output.reshape(grad_output.shape[0] * grad_output.shape[1],
grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
total_input.shape[2])
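On the `view` -> `reshape` change and its TODO above: `.view()` requires a memory-contiguous tensor and raises an error otherwise, whereas `.reshape()` silently falls back to a copy, and that extra copy is the cost the TODO asks about. A small standalone illustration:

# Tiny sketch of the view-vs-reshape distinction behind the change above.
import torch

x = torch.randn(4, 3, 2).transpose(0, 1)   # non-contiguous after transpose
assert not x.is_contiguous()
try:
    x.view(-1, 2)
except RuntimeError as e:
    print("view failed:", e)
y = x.reshape(-1, 2)                        # works, but allocates a copy here
assert y.data_ptr() != x.data_ptr()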
20 changes: 20 additions & 0 deletions megatron/optimizer/optimizer.py
@@ -265,6 +265,19 @@ def allreduce_embedding_grads(self, args):
"""All-reduce both word and position embeddings."""
self.allreduce_word_embedding_grads(args)
self.allreduce_position_embedding_grads(args)

def allreduce_key_value_grads(self, args):
# TODO: models[0] ?
unwrapped_model = self.models[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
for layer in unwrapped_model.language_model.encoder.layers:
kv_weight = layer.self_attention.key_value.weight
if args.DDP_impl == 'local':
grad = kv_weight.main_grad
else:
grad = kv_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_tensor_model_parallel_group())
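Why this all-reduce is needed: with `--attention-head-type multiquery` the `key_value` projection is a plain linear layer replicated on every tensor-parallel rank (see `get_linear_layer` in transformer.py above), but each rank back-propagates only through its own partition of the query heads, so the replicas accumulate different partial gradients and the true gradient is their sum. A hedged single-process sketch of that argument (hypothetical toy shapes, not the distributed code path):

# Single-process sketch: the full gradient of a replicated key weight equals
# the sum of per-"rank" gradients computed from disjoint slices of query heads,
# which is what torch.distributed.all_reduce (SUM) recovers.
import torch

torch.manual_seed(0)
h, np_, hn = 16, 4, 8
hidden = torch.randn(3, h)
kv_weight_full = torch.randn(hn, h, requires_grad=True)
q_weight = torch.randn(np_ * hn, h)

def loss_for_heads(kv_weight, head_slice):
    k = hidden @ kv_weight.t()                     # shared keys, all heads
    q = (hidden @ q_weight.t()).view(3, np_, hn)   # per-head queries
    return (q[:, head_slice] @ k.t()).sum()        # this "rank"'s heads only

# Reference: all heads on one "rank".
loss_for_heads(kv_weight_full, slice(0, np_)).backward()
full_grad = kv_weight_full.grad.clone()

# Two "ranks", each with a replica of the key weight and half the heads.
grads = []
for head_slice in (slice(0, 2), slice(2, 4)):
    kv_replica = kv_weight_full.detach().clone().requires_grad_(True)
    loss_for_heads(kv_replica, head_slice).backward()
    grads.append(kv_replica.grad)

assert torch.allclose(full_grad, grads[0] + grads[1], atol=1e-5)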


def allreduce_layernorm_grads(self, args):
@@ -310,6 +323,13 @@ def reduce_model_grads(self, args, timers):
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()

# All-reduce key-value grads if needed.
if args.attention_head_type == "multiquery":
timers('backward-key-value-all-reduce').start()
self.allreduce_key_value_grads(args)
timers('backward-key-value-all-reduce').stop()



class MixedPrecisionOptimizer(MegatronOptimizer):
"""Base class for both the float-16 and the distributed optimizer.