Commit df6e87a: Add MLA reference script (parent dfadad7)

2 files changed (+314, -22 lines)

ch04/05_mla/README.md (66 additions, 22 deletions)
@@ -1,52 +1,79 @@
# Multi-Head Latent Attention (MLA)

This bonus material illustrates the memory savings when using Multi-Head Latent Attention (MLA) over regular Multi-Head Attention (MHA).

&nbsp;
## Introduction

In [../04_gqa](../04_gqa), we discussed Grouped-Query Attention (GQA) as a computational-efficiency workaround for MHA. Ablation studies (such as those in the [original GQA paper](https://arxiv.org/abs/2305.13245) and the [Llama 2 paper](https://arxiv.org/abs/2307.09288)) show that it performs comparably to standard MHA in terms of LLM modeling performance.

Now, Multi-Head Latent Attention (MLA), which is used in [DeepSeek V2, V3, and R1](https://arxiv.org/abs/2412.19437), offers a different memory-saving strategy that also pairs particularly well with KV caching. Instead of sharing key and value heads like GQA, MLA compresses the key and value tensors into a lower-dimensional space before storing them in the KV cache.

At inference time, these compressed tensors are projected back to their original size before being used, as shown in the figure below. This adds an extra matrix multiplication but reduces memory usage.

&nbsp;

![MLA](https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/mla-memory/1.webp)

&nbsp;

(As a side note, the queries are also compressed, but only during training, not inference.)
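
To make the compress-then-up-project idea more concrete, here is a minimal PyTorch sketch of an MLA-style attention block with a latent KV cache. This is an illustrative toy (the class name, hyperparameters, and simplified cache handling are all made up here); it omits RoPE (DeepSeek uses a decoupled RoPE path), query compression, and other details, and it is not the implementation used in the scripts referenced later in this README:

```python
# Minimal sketch of MLA-style attention: only a low-dimensional latent is cached,
# and full-sized keys/values are recovered from it with extra up-projections.
import torch
import torch.nn as nn


class SimpleMLA(nn.Module):
    def __init__(self, d_model, num_heads, latent_dim):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)       # queries (full size)
        self.W_dkv = nn.Linear(d_model, latent_dim, bias=False)  # down-projection -> cached latent
        self.W_uk = nn.Linear(latent_dim, d_model, bias=False)   # up-projection to keys
        self.W_uv = nn.Linear(latent_dim, d_model, bias=False)   # up-projection to values
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_latent_cache=None):
        b, num_tokens, d_model = x.shape

        # Only the compressed latent is stored, not the full keys and values.
        new_latent = self.W_dkv(x)  # (b, num_tokens, latent_dim)
        latent = new_latent if kv_latent_cache is None else torch.cat(
            [kv_latent_cache, new_latent], dim=1)

        # Extra matmuls: project the latent back up to full-sized keys and values.
        queries = self.W_q(x)
        keys = self.W_uk(latent)
        values = self.W_uv(latent)

        # Split heads: (b, num_heads, seq_len, head_dim)
        q = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        k = keys.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = values.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Causal masking for the prefill step; during single-token decoding the new
        # query may attend to all cached positions, so no mask is needed.
        is_causal = kv_latent_cache is None
        ctx = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
        ctx = ctx.transpose(1, 2).reshape(b, num_tokens, d_model)
        return self.out_proj(ctx), latent  # latent is the updated KV cache


# Toy usage: prefill a 10-token prompt, then decode one more token from the cache.
mla = SimpleMLA(d_model=256, num_heads=8, latent_dim=64)
out, cache = mla(torch.randn(1, 10, 256))
print(cache.shape)  # torch.Size([1, 10, 64]) -- instead of two full-sized (1, 10, 256) tensors
out, cache = mla(torch.randn(1, 1, 256), kv_latent_cache=cache)
print(cache.shape)  # torch.Size([1, 11, 64])
```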

By the way, as mentioned earlier, MLA is not new in DeepSeek V3, as its [DeepSeek V2 predecessor](https://arxiv.org/abs/2405.04434) also used (and even introduced) it. Also, the V2 paper contains a few interesting ablation studies that may explain why the DeepSeek team chose MLA over GQA (see the figure below).

&nbsp;

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/mla-memory/2.webp" alt="GQA" width="500px" />

&nbsp;

As shown in the figure above, GQA appears to perform worse than MHA, whereas MLA offers better modeling performance than MHA, which is likely why the DeepSeek team chose MLA over GQA. (It would have been interesting to see the "KV Cache per Token" savings comparison between MLA and GQA as well!)

To summarize this section before we move on to the next architecture component: MLA is a clever trick to reduce KV cache memory use while even slightly outperforming MHA in terms of modeling performance.

&nbsp;
## MLA Memory Savings

The memory savings are mostly reflected in the KV storage. We can compute the MLA KV storage size with the following formula:

bytes ≈ batch_size × seqlen × n_layers × latent_dim × bytes_per_elem

In contrast, the MHA KV cache memory is computed as follows:

bytes ≈ batch_size × seqlen × n_layers × embed_dim × 2 (K,V) × bytes_per_elem

This means that, in MLA, we reduce "embed_dim × 2 (K,V)" to "latent_dim", since we only store the compressed latent representation instead of the full key and value vectors, as shown in the earlier figure above.
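
As a quick back-of-the-envelope illustration of the two formulas above, here is a small Python sketch (the config values are made-up placeholders; the [memory_estimator_mla.py](memory_estimator_mla.py) script below performs a more thorough estimate):

```python
# Back-of-the-envelope KV-cache size comparison based on the two formulas above.
# The config values below are illustrative placeholders, not a specific model.
batch_size, seqlen, n_layers = 1, 8192, 16
embed_dim, latent_dim = 2048, 1024
bytes_per_elem = 2  # e.g., bfloat16

mha_bytes = batch_size * seqlen * n_layers * embed_dim * 2 * bytes_per_elem  # K and V
mla_bytes = batch_size * seqlen * n_layers * latent_dim * bytes_per_elem     # latent only

print(f"MHA KV cache: {mha_bytes / 1024**3:.2f} GB")       # 1.00 GB
print(f"MLA KV cache: {mla_bytes / 1024**3:.2f} GB")       # 0.25 GB
print(f"Ratio (MHA / MLA): {mha_bytes / mla_bytes:.2f}x")  # 4.00x
print(f"Savings: {1 - mla_bytes / mha_bytes:.2%}")         # 75.00%
```

With these simplified formulas, the ratio is exactly `embed_dim × 2 / latent_dim`; the script output below reports similar numbers (4.03x and 75.19%) for the `emb_dim 2048 -> latent_dim 1024` setting.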

You can use the [memory_estimator_mla.py](memory_estimator_mla.py) script in this folder to apply this formula to different model configs and see how much memory you can save by using MLA over MHA:

```bash
➜ uv run memory_estimator_mla.py \
@@ -80,13 +107,19 @@ Ratio (MHA / MLA) : 4.03x
Savings (MLA vs MHA): 75.19%
```

Note that the compression above (`--emb_dim 2048 -> latent_dim 1024`) is chosen to achieve a similar saving as with GQA. In practice, the compression is a hyperparameter that needs to be carefully investigated, as choosing `latent_dim` too small can have a negative impact on modeling performance (similar to choosing too many `n_kv_groups` in GQA).

The savings when using MLA over MHA are further shown in the plot below for different `latent_dim` values as a function of the context length:

&nbsp;

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/mla-memory/3.webp?2" alt="GQA" width="500px" />

&nbsp;

@@ -97,15 +130,23 @@ You can reproduce the plot via `uv run plot_memory_estimates_mla.py`.
&nbsp;
## MLA Code Examples

The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_mla.py](gpt_with_kv_mla.py) scripts in this folder provide hands-on examples for comparing the MHA and MLA memory usage in the context of a GPT model implementation.

Here, the MLA code is inspired by the [https://huggingface.co/bird-of-paradise/deepseek-mla](https://huggingface.co/bird-of-paradise/deepseek-mla) implementation.

Note that MLA can also be used in combination with [GQA](../04_gqa), but for simplicity, this is not done here. (Currently, I am also not aware of a prominent LLM doing this.)

Also note that the model is not trained and thus generates nonsensical text. However, you can use it as a drop-in replacement for the standard GPT model in chapters 5-7 and train it.

Lastly, this implementation uses the KV cache explained in [another bonus section](../03_kv-cache), so the memory savings are more pronounced.
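
The "Max memory allocated" numbers reported by these scripts are the kind of measurement you get from PyTorch's peak-memory counters. As a rough sketch of that general pattern (not necessarily the scripts' exact code; the generation step is only a placeholder comment):

```python
# Generic pattern for measuring peak GPU memory around a generation run
# (sketch only; replace the placeholder comment with your own model call).
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    torch.cuda.reset_peak_memory_stats(device)

# ... run the model / token generation loop on `device` here ...

if device.type == "cuda":
    max_mem_gb = torch.cuda.max_memory_allocated(device) / 1024**3
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")
```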

```bash
uv run gpt_with_kv_mha.py \
@@ -138,5 +179,8 @@ Max memory allocated: 0.68 GB
```

The reason we are not seeing such a big saving as in the plots above is twofold:

1. I use a smaller configuration to have the model finish the generation in a reasonable time.
2. More importantly, we are looking at the whole model here, not just the attention mechanism; the fully-connected layers in the model take up most of the memory (but this is a topic for a separate analysis).
