Skip to content

Commit 98fa43f

Browse files
jthakurH (Jay Thakur)
authored and committed
Add batch splitting in attention layer for decode to hide NIC latency (#2334)
* Add batch splitting in attention layer for decode to hide NIC latency * Update modeling_llama.py * Update utils.py * Update modeling_llama.py * Fix code style issues: typo fix, PEP 8 formatting, and indentation - Fix typo: kv_cahe -> kv_cache in comment - Remove extra spaces before colons to follow PEP 8 - Fix indentation in examples/text-generation/utils.py - Apply automatic code formatting with ruff --------- Co-authored-by: Jay Thakur <jaythaku@habana.ai>
1 parent d8576f9 commit 98fa43f

File tree

6 files changed

+82
-16
lines changed

6 files changed

+82
-16
lines changed

examples/text-generation/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ Here are a few settings you may be interested in:
133133
- `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it
134134
- `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it
135135
- `--attn_batch_split` specifies the number of smaller batches into which attention and MLP processing are split to improve parallelization. By default, no splitting is performed (value is 1). Splitting is enabled only for prompt processing. This configuration is most effective for batch sizes (BS) > 125 and tensor parallelism (TP) >= 2, with a recommended value of '3' splits. This feature is thoroughly tested with Llama 2 70B but may be useful for other models as well.
136+
- `--decode_attn_batch_split` specifies the number of smaller batches into which attention and MLP processing are split to improve parallelization. By default, no splitting is performed (value is 1). Splitting is enabled only for decode.
136137
- `--dynamo_specialize_float` enables specialization for float inputs by setting `specialize_float=True` in the `torch._dynamo` configuration. This option is applicable only when using `torch.compile` and can enhance performance, particularly in models utilizing FP8 quantization.
137138

138139
For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command:

examples/text-generation/run_generation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,12 @@ def setup_parser(parser):
434434
type=int,
435435
help="Specify the batch size split for attention and mlp layers. 1 for no split. This is enabled only for prompt.",
436436
)
437+
parser.add_argument(
438+
"--decode_attn_batch_split",
439+
default=1,
440+
type=int,
441+
help="Specify the batch size split for attention and mlp layers. 1 for no split. This is enabled only for decode.",
442+
)
437443
parser.add_argument(
438444
"--regional_compile",
439445
action="store_true",

examples/text-generation/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
747747
generation_config.trust_remote_code = args.trust_remote_code
748748
generation_config.valid_sequence_lengths = None
749749
generation_config.attn_batch_split = args.attn_batch_split
750+
generation_config.decode_attn_batch_split = args.decode_attn_batch_split
750751

751752
return generation_config
752753

@@ -770,9 +771,12 @@ def exclude_hpu_graph_configs(args):
770771

771772
def initialize_model(args, logger):
772773
setup_distributed(args)
773-
if not args.world_size > 0 and args.attn_batch_split > 1:
774-
logger.warning("Disabling attention batch splitting as it's unnecessary for single-card execution")
774+
if args.world_size <= 1 and args.attn_batch_split > 1:
775+
logger.warning("Disabling attention batch splitting for prompt as it's unnecessary for single-card execution")
775776
args.attn_batch_split = 1
777+
if args.world_size <= 1 and args.decode_attn_batch_split > 1:
778+
logger.warning("Disabling attention batch splitting for decode as it's unnecessary for single-card execution")
779+
args.decode_attn_batch_split = 1
776780
if exclude_hpu_graph_configs(args):
777781
args.limit_hpu_graphs = False
778782
override_prints(args.global_rank == 0 or args.verbose_workers, logger)

optimum/habana/transformers/generation/configuration_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class GaudiGenerationConfig(GenerationConfig):
4141
Whether to use fast softmax with reduced precision if use Habana flash attention.
4242
attn_batch_split (`int`, *optional*):
4343
Specify the batch size split for attention and mlp layers. 1 for no split. This is enabled only for prompt.
44+
decode_attn_batch_split (`int`, *optional*):
45+
Specify the batch size split for attention and mlp layers for decode. 1 for no split.
4446
logits_bf16 (`bool`, *optional*):
4547
Keep logits in bf16.
4648
"""
@@ -65,4 +67,5 @@ def __init__(self, **kwargs):
6567
self.use_fused_rope = kwargs.get("use_fused_rope", None)
6668
self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None)
6769
self.attn_batch_split = kwargs.get("attn_batch_split", 1)
70+
self.decode_attn_batch_split = kwargs.get("decode_attn_batch_split", 1)
6871
self.logits_bf16 = kwargs.get("logits_bf16", None)

optimum/habana/transformers/generation/utils.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,17 @@ def _pad_past_key_values(self, model_kwargs):
566566
# Mark step if lazy mode is enabled
567567
if lazy_mode:
568568
self.htcore_generation.mark_step()
569+
# For Non-MQA models with decode_attn_batch_split > 1, past_key_values is a list of list of list (k and v)
570+
elif not is_mqa_model and model_kwargs.get("decode_attn_batch_split", 1) > 1:
571+
for i, layer in enumerate(past_key_values): # Iterate over layers
572+
for j, split_kv_caches in enumerate(layer): # Iterate over splitted kv_cache
573+
for k, k_or_v in enumerate(split_kv_caches): # Iterate over k and v
574+
if torch.is_tensor(k_or_v) and k_or_v.shape[-2] == kv_cache_len_pad_amount:
575+
# tensor(batch_size/num_splits, n_heads, kv_cache_len, head_dim)
576+
past_key_values[i][j][k] = torch.nn.functional.pad(k_or_v, (0, 0, 0, pad_amount))
577+
# Mark step if lazy mode is enabled
578+
if lazy_mode:
579+
self.htcore_generation.mark_step()
569580
# For Non-MQA models, the past_key_values is a list of lists (k and v)
570581
else:
571582
for i, layer in enumerate(past_key_values): # Iterate over layers
@@ -1599,9 +1610,12 @@ def generate(
15991610
# prepare for allocate kv cache
16001611
model_kwargs["reuse_cache"] = generation_config.reuse_cache
16011612

1602-
# prepare for attention batch splitting
1613+
# prepare for attention batch splitting for prompt
16031614
model_kwargs["attn_batch_split"] = generation_config.attn_batch_split
16041615

1616+
# prepare for attention batch splitting for decode
1617+
model_kwargs["decode_attn_batch_split"] = generation_config.decode_attn_batch_split
1618+
16051619
# Keep logits in bf16
16061620
model_kwargs["logits_bf16"] = kwargs.get("logits_bf16")
16071621

@@ -2934,9 +2948,13 @@ def _sample(
29342948
if "inputs_embeds" in model_inputs
29352949
else None
29362950
)
2951+
if model_kwargs["decode_attn_batch_split"] > 1:
2952+
output_past_key_values_shape = outputs.past_key_values[0][0][0].shape
2953+
else:
2954+
output_past_key_values_shape = outputs.past_key_values[0][0].shape
29372955
do_padding = (
29382956
key_to_check is not None
2939-
and outputs.past_key_values[0][0].shape[2] == model_inputs[key_to_check].shape[1]
2957+
and output_past_key_values_shape[2] == model_inputs[key_to_check].shape[1]
29402958
and generation_config.max_new_tokens > 1
29412959
)
29422960

optimum/habana/transformers/models/llama/modeling_llama.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,7 @@ def forward(
945945
cache_idx: int = None,
946946
num_virtual_tokens: int = None,
947947
attn_batch_split: int = 1,
948+
decode_attn_batch_split: int = 1,
948949
prev_layer_residual: Optional[torch.Tensor] = None,
949950
**kwargs,
950951
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -960,8 +961,13 @@ def forward(
960961
- add new arg flash_attention_causal_mask
961962
- add new arg flash_attention_fast_softmax
962963
"""
963-
if attn_batch_split > 1 and past_key_value is None:
964+
if (attn_batch_split > 1 and past_key_value is None) or (
965+
decode_attn_batch_split > 1 and past_key_value is not None
966+
):
964967
# Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split
968+
if past_key_value is not None:
969+
attn_batch_split = decode_attn_batch_split
970+
965971
batch_size = attention_mask.size(0)
966972
base_split_size = batch_size // attn_batch_split
967973
remainder = batch_size % attn_batch_split
@@ -984,11 +990,11 @@ def forward(
984990
split_hidden_states[i] = self.post_mlp(hidden_states[i], prev_layer_residual[i])
985991

986992
residual[i] = split_hidden_states[i]
987-
split_hidden_states[i], self_attn_weights, present_key_value = self.pre_attn(
993+
split_hidden_states[i], self_attn_weights, inter_present_key_value = self.pre_attn(
988994
hidden_states=split_hidden_states[i],
989995
attention_mask=sub_attention_mask[i],
990996
position_ids=sub_position_ids[i],
991-
past_key_value=past_key_value,
997+
past_key_value=past_key_value[i] if past_key_value is not None else past_key_value,
992998
use_cache=use_cache,
993999
cache_position=cache_position,
9941000
position_embeddings=position_embeddings,
@@ -1006,10 +1012,17 @@ def forward(
10061012
)
10071013
self.self_attn.attention_all_reduce(split_hidden_states[i])
10081014
if use_cache:
1009-
split_present_key_values.append(present_key_value)
1015+
split_present_key_values.append(inter_present_key_value)
10101016

10111017
self_attn_weights = torch.cat(split_attn_weights, dim=0) if split_attn_weights else None
1012-
present_key_value = [torch.cat(tensors, dim=0) for tensors in zip(*split_present_key_values)]
1018+
if decode_attn_batch_split > 1:
1019+
# Instead of concatenating, keep them as a list of lists
1020+
# [[k1, v1], [k2, v2]]
1021+
present_key_value = split_present_key_values
1022+
else:
1023+
# Concatenate along the batch dimension to form the final present_key_value
1024+
# [k, v] where k and v have batch dimension = sum of all splits
1025+
present_key_value = [torch.cat(tensors, dim=0) for tensors in zip(*split_present_key_values)]
10131026

10141027
int_residual_splits = []
10151028
for i in range(attn_batch_split):
@@ -1054,7 +1067,9 @@ def forward(
10541067
if use_cache:
10551068
outputs += (present_key_value,)
10561069
# Store the residual splits to add them in the beginning of the next layer
1057-
if attn_batch_split > 1 and past_key_value is None:
1070+
if (attn_batch_split > 1 and past_key_value is None) or (
1071+
decode_attn_batch_split > 1 and past_key_value is not None
1072+
):
10581073
outputs += (int_residual_splits,)
10591074

10601075
return outputs
@@ -1197,6 +1212,7 @@ def forward(
11971212
lazy_mode: Optional[bool] = True,
11981213
num_virtual_tokens: int = None,
11991214
attn_batch_split: int = 1,
1215+
decode_attn_batch_split: int = 1,
12001216
**kwargs,
12011217
) -> BaseModelOutputWithPast:
12021218
"""
@@ -1246,8 +1262,10 @@ def forward(
12461262
past_seen_tokens = past_key_values[0][0][2]
12471263
else:
12481264
# HPU uses legacy cache path (use_new_cache = False)
1249-
past_seen_tokens = past_key_values[0][0].shape[2]
1250-
1265+
if decode_attn_batch_split > 1:
1266+
past_seen_tokens = past_key_values[0][0][0].shape[2]
1267+
else:
1268+
past_seen_tokens = past_key_values[0][0].shape[2]
12511269
if ignore_cache_position is False:
12521270
if cache_position is None:
12531271
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -1284,8 +1302,11 @@ def forward(
12841302
htcore.mark_step()
12851303

12861304
split_prompt = False
1287-
prev_layer_residual = None
1288-
if attn_batch_split > 1 and past_key_values is None:
1305+
if (attn_batch_split > 1 and past_key_values is None) or (
1306+
decode_attn_batch_split > 1 and past_key_values is not None
1307+
):
1308+
if past_key_values is not None:
1309+
attn_batch_split = decode_attn_batch_split
12891310
# Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split
12901311
batch_size = hidden_states.size(0)
12911312
base_split_size = batch_size // attn_batch_split
@@ -1295,6 +1316,8 @@ def forward(
12951316
hidden_states_split = torch.split(hidden_states, split_sizes, dim=0)
12961317
split_prompt = True
12971318

1319+
prev_layer_residual = None
1320+
12981321
for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
12991322
if (
13001323
lazy_mode
@@ -1306,10 +1329,12 @@ def forward(
13061329
# Calling the layer with positional arguments
13071330
# This is a workaround for an issue with DeepSpeed where
13081331
# it cannot handle keyword arguments and throws a RuntimError
1309-
use_prev_layer_residual = attn_batch_split > 1 and past_key_values is None
1332+
past_key_value = None if past_key_values is None else past_key_values[layer_idx]
1333+
use_prev_layer_residual = (attn_batch_split > 1 and past_key_value is None) or (
1334+
decode_attn_batch_split > 1 and past_key_value is not None
1335+
)
13101336
layer_prev_layer_residual = prev_layer_residual if use_prev_layer_residual else None
13111337
layer_hidden_states = hidden_states_split if split_prompt else hidden_states
1312-
past_key_value = None if past_key_values is None else past_key_values[layer_idx]
13131338
layer_outputs = decoder_layer(
13141339
layer_hidden_states,
13151340
causal_mask,
@@ -1330,6 +1355,7 @@ def forward(
13301355
cache_idx,
13311356
num_virtual_tokens,
13321357
attn_batch_split,
1358+
decode_attn_batch_split,
13331359
layer_prev_layer_residual,
13341360
)
13351361
if use_prev_layer_residual:
@@ -1345,6 +1371,11 @@ def forward(
13451371

13461372
hidden_states = self.norm(hidden_states)
13471373

1374+
if lazy_mode and decode_attn_batch_split > 1 and torch.distributed.get_world_size() > 1:
1375+
# For synchronization, put a barrier here so that all processes
1376+
# finish computation before moving to the next step.
1377+
# Recommended to use for llama 405B model during decoding with batch split
1378+
torch.distributed.barrier()
13481379
next_cache = next_decoder_cache if use_cache else None
13491380
if not use_new_cache and isinstance(next_cache, Cache):
13501381
next_cache = next_cache.to_legacy_cache()
@@ -1409,6 +1440,7 @@ def forward(
14091440
lazy_mode: Optional[bool] = True,
14101441
num_virtual_tokens: int = None,
14111442
attn_batch_split: int = 1,
1443+
decode_attn_batch_split: int = 1,
14121444
**kwargs: Unpack[TransformersKwargs],
14131445
) -> CausalLMOutputWithPast:
14141446
if self.generation_config.use_fused_rope is False:
@@ -1436,6 +1468,7 @@ def forward(
14361468
lazy_mode=lazy_mode,
14371469
num_virtual_tokens=num_virtual_tokens,
14381470
attn_batch_split=attn_batch_split,
1471+
decode_attn_batch_split=decode_attn_batch_split,
14391472
**kwargs,
14401473
)
14411474

@@ -1563,6 +1596,7 @@ def prepare_inputs_for_generation(
15631596
"lazy_mode": kwargs.get("lazy_mode"),
15641597
"num_virtual_tokens": kwargs.get("num_virtual_tokens"),
15651598
"attn_batch_split": kwargs.get("attn_batch_split"),
1599+
"decode_attn_batch_split": kwargs.get("decode_attn_batch_split"),
15661600
}
15671601
)
15681602
return model_inputs

0 commit comments

Comments (0)