Add batch splitting in attention layer for decode to hide NIC latency (#2334)
* Add batch splitting in attention layer for decode to hide NIC latency
* Update modeling_llama.py
* Update utils.py
* Update modeling_llama.py
* Fix code style issues: typo fix, PEP 8 formatting, and indentation
- Fix typo: kv_cahe -> kv_cache in comment
- Remove extra spaces before colons to follow PEP 8
- Fix indentation in examples/text-generation/utils.py
- Apply automatic code formatting with ruff
---------
Co-authored-by: Jay Thakur <jaythaku@habana.ai>
examples/text-generation/README.md: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ Here are a few settings you may be interested in:
 - `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it
 - `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it
 - `--attn_batch_split` specifies the number of smaller batches into which attention and MLP processing are split to improve parallelization. By default, no splitting is performed (value is 1). Splitting is enabled only for prompt processing. This configuration is most effective for batch sizes (BS) > 125 and tensor parallelism (TP) >= 2, with a recommended value of '3' splits. This feature is thoroughly tested with Llama 2 70B but may be useful for other models as well.
+- `--decode_attn_batch_split` specifies the number of smaller batches to split the attention and MLP processing into for better parallelization. By default, no splitting is performed (value is 1). Splitting is enabled only for decode.
 - `--dynamo_specialize_float` enables specialization for float inputs by setting `specialize_float=True` in the `torch._dynamo` configuration. This option is applicable only when using `torch.compile` and can enhance performance, particularly in models utilizing FP8 quantization.
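The batch-splitting idea behind `--decode_attn_batch_split` can be sketched in plain Python: the decode batch is divided into N sub-batches so that compute on one sub-batch can overlap with the communication (NIC) latency incurred by another in a tensor-parallel setup. The helper names below (`split_batch`, `decode_step`) are hypothetical illustrations, not functions from the actual PR.

```python
# Hypothetical sketch of decode-time batch splitting, not the code from this PR.
# Splitting the batch lets the attention/MLP compute of one sub-batch overlap
# with the collective-communication (NIC) latency of another sub-batch.

def split_batch(batch, num_splits):
    """Split a list of rows into up to `num_splits` near-equal sub-batches."""
    if num_splits <= 1:
        return [batch]
    size, rem = divmod(len(batch), num_splits)
    splits, start = [], 0
    for i in range(num_splits):
        end = start + size + (1 if i < rem else 0)  # spread the remainder
        splits.append(batch[start:end])
        start = end
    return [s for s in splits if s]  # drop empty splits when batch is small

def decode_step(batch, num_splits, attn_mlp):
    """Run attention+MLP per sub-batch, then re-concatenate the outputs."""
    outputs = []
    for sub in split_batch(batch, num_splits):
        # In the real kernel, later sub-batches would start compute while
        # earlier ones wait on all-reduce; here we just process sequentially.
        outputs.extend(attn_mlp(sub))
    return outputs
```

For example, with a batch of 7 rows and 3 splits, the sub-batch sizes are 3, 2, and 2, and the concatenated output matches running the whole batch at once.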
For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command: