# Multi-Head Latent Attention (MLA)

This bonus material illustrates the memory savings when using Multi-Head Latent Attention (MLA) over regular Multi-Head Attention (MHA).

## Introduction

In [../04_gqa](../04_gqa), we discussed Grouped-Query Attention (GQA) as a computational-efficiency workaround for MHA. And ablation studies (such as those in the [original GQA paper](https://arxiv.org/abs/2305.13245) and the [Llama 2 paper](https://arxiv.org/abs/2307.09288)) show it performs comparably to standard MHA in terms of LLM modeling performance.

Now, Multi-Head Latent Attention (MLA), which is used in [DeepSeek V2, V3, and R1](https://arxiv.org/abs/2412.19437), offers a different memory-saving strategy that also pairs particularly well with KV caching. Instead of sharing key and value heads like GQA, MLA compresses the key and value tensors into a lower-dimensional space before storing them in the KV cache.

At inference time, these compressed tensors are projected back to their original size before being used, as shown in the figure below. This adds an extra matrix multiplication but reduces memory usage.

(As a side note, the queries are also compressed, but only during training, not inference.)
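
To make the compress-then-expand idea more concrete, here is a minimal PyTorch sketch of the core MLA projections. It is an illustrative simplification (it omits details such as per-head handling and RoPE), and the names (`LatentKV`, `W_dkv`, `W_uk`, `W_uv`, `latent_dim`) are made up for this sketch rather than taken from the code in this folder:

```python
import torch
import torch.nn as nn

class LatentKV(nn.Module):
    """Sketch of MLA's key/value compression: cache a small latent, expand it on demand."""
    def __init__(self, emb_dim, latent_dim):
        super().__init__()
        self.W_dkv = nn.Linear(emb_dim, latent_dim, bias=False)  # shared down-projection (compression)
        self.W_uk = nn.Linear(latent_dim, emb_dim, bias=False)   # up-projection to full-size keys
        self.W_uv = nn.Linear(latent_dim, emb_dim, bias=False)   # up-projection to full-size values

    def compress(self, x):
        # (batch, seq_len, emb_dim) -> (batch, seq_len, latent_dim); this is what would be cached
        return self.W_dkv(x)

    def expand(self, latent):
        # Reconstruct full-size keys and values from the cached latent (extra matmuls at inference)
        return self.W_uk(latent), self.W_uv(latent)


mla_kv = LatentKV(emb_dim=2048, latent_dim=1024)
x = torch.randn(1, 4, 2048)            # (batch, seq_len, emb_dim)
latent = mla_kv.compress(x)            # 1024 values per token instead of 2 x 2048 for K and V
keys, values = mla_kv.expand(latent)   # computed on the fly when attention is evaluated
print(latent.shape, keys.shape, values.shape)
```

Storing `latent` instead of the full `keys` and `values` is exactly where the memory savings discussed in the next section come from.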

By the way, as mentioned earlier, MLA is not new in DeepSeek V3, as its [DeepSeek V2 predecessor](https://arxiv.org/abs/2405.04434) also used (and even introduced) it. Also, the V2 paper contains a few interesting ablation studies that may explain why the DeepSeek team chose MLA over GQA (see the figure below).

As shown in the figure above, GQA appears to perform worse than MHA, whereas MLA offers better modeling performance than MHA, which is likely why the DeepSeek team chose MLA over GQA. (It would have been interesting to see the "KV Cache per Token" savings comparison between MLA and GQA as well!)

To summarize this section, before we move on to the next architecture component, MLA is a clever trick to reduce KV cache memory use while even slightly outperforming MHA in terms of modeling performance.

## MLA Memory Savings

The memory savings are mostly reflected in the KV storage. We can compute the KV storage size with the following formula:

- MHA: KV-cache size ≈ batch_size × context_length × n_layers × embed_dim × 2 (K, V) × bytes_per_element
- MLA: KV-cache size ≈ batch_size × context_length × n_layers × latent_dim × bytes_per_element

This means, in MLA, we reduce "embed_dim × 2 (K,V)" to "latent_dim", since we only store the compressed latent representation instead of the full key and value vectors, as shown in the earlier figure above.
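
To make the arithmetic tangible, here is a small back-of-the-envelope calculation in plain Python. The config values below are illustrative and not taken from the script output shown further down (the script may count a few additional terms and therefore report a slightly different ratio):

```python
# Back-of-the-envelope KV-storage estimate, assuming 2 bytes per element (float16/bfloat16)
batch_size = 1
context_length = 8192
n_layers = 24
emb_dim = 2048
latent_dim = 1024
bytes_per_elem = 2

# MHA: full keys and values -> emb_dim * 2 values per token per layer
mha_bytes = batch_size * context_length * n_layers * emb_dim * 2 * bytes_per_elem

# MLA: only the compressed latent -> latent_dim values per token per layer
mla_bytes = batch_size * context_length * n_layers * latent_dim * bytes_per_elem

print(f"MHA KV cache: {mha_bytes / 1e9:.2f} GB")                          # 1.61 GB
print(f"MLA KV cache: {mla_bytes / 1e9:.2f} GB")                          # 0.40 GB
print(f"Ratio (MHA / MLA): {mha_bytes / mla_bytes:.2f}x")                 # 4.00x
print(f"Savings (MLA vs MHA): {100 * (1 - mla_bytes / mha_bytes):.2f}%")  # 75.00%
```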

You can use the [memory_estimator_mla.py](memory_estimator_mla.py) script in this folder to apply this to different model configs and see how much memory you can save by using MLA over MHA:

```bash
➜ uv run memory_estimator_mla.py \
...
Ratio (MHA / MLA) : 4.03x
Savings (MLA vs MHA): 75.19%
```

Note that the compression above (`--emb_dim 2048 -> latent_dim 1024`) was chosen to achieve a similar saving as for GQA. In practice, the compression is a hyperparameter that needs to be carefully investigated, as choosing `latent_dim` to be too small can have a negative impact on the modeling performance (similar to choosing too many `n_kv_groups` in GQA).

The savings when using MLA over MHA are further shown in the plot below for different `latent_dim` values as a function of the context length:

You can reproduce the plot via `uv run plot_memory_estimates_mla.py`.

## MLA Code Examples

The [gpt_with_kv_mha.py](gpt_with_kv_mha.py) and [gpt_with_kv_mla.py](gpt_with_kv_mla.py) scripts in this folder provide hands-on examples for comparing the MHA and MLA memory usage in the context of a GPT model implementation.

Here, the MLA code is inspired by the [https://huggingface.co/bird-of-paradise/deepseek-mla](https://huggingface.co/bird-of-paradise/deepseek-mla) implementation.

Note that MLA can also be used in combination with [GQA](../04_gqa), but for simplicity, this is not done here. (Currently, I am also not aware of a prominent LLM doing this.)

Also note that the model is not trained and thus generates nonsensical text. However, you can use it as a drop-in replacement for the standard GPT model in chapters 5-7 and train it.

Lastly, this implementation uses the KV cache explained in [another bonus section](../03_kv-cache), so the memory savings are more pronounced.
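
To show what changes relative to a regular KV cache at generation time, here is a simplified sketch of the caching logic. This is not the actual code from the scripts in this folder, and the names below are again illustrative: only the compressed latent is appended to the cache, and full-size keys and values are re-created on the fly when attention is computed.

```python
import torch
import torch.nn as nn

emb_dim, latent_dim = 768, 192
W_dkv = nn.Linear(emb_dim, latent_dim, bias=False)  # compression (what gets cached)
W_uk = nn.Linear(latent_dim, emb_dim, bias=False)   # expand cached latent to keys
W_uv = nn.Linear(latent_dim, emb_dim, bias=False)   # expand cached latent to values

latent_cache = []  # per layer, the cache holds (batch, 1, latent_dim) entries instead of full K and V

def generation_step(x_new):
    # x_new: hidden state of the newly generated token, shape (batch, 1, emb_dim)
    latent_cache.append(W_dkv(x_new))           # store only the compressed latent
    latent = torch.cat(latent_cache, dim=1)     # (batch, tokens_so_far, latent_dim)
    keys, values = W_uk(latent), W_uv(latent)   # expanded only transiently for the attention step
    return keys, values

for _ in range(3):                              # pretend we generate three tokens
    keys, values = generation_step(torch.randn(1, 1, emb_dim))
print(keys.shape, values.shape)                 # full-size K/V, but only the latent was stored
```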
```bash
uv run gpt_with_kv_mha.py \
...
Max memory allocated: 0.68 GB
```

The reason why we are not seeing such a big saving as in the plots above is 2-fold:

1. I use a smaller configuration to have the model finish the generation in a reasonable time.
2. More importantly, we are looking at the whole model here, not just the attention mechanism; the fully-connected layers in the model take up most of the memory (but this is a topic for a separate analysis).
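
As a final practical aside, peak-memory numbers like the "Max memory allocated" line above can be tracked for any PyTorch run with `torch.cuda.max_memory_allocated`. Below is a minimal sketch, assuming a CUDA device is available (the scripts in this folder may measure it differently):

```python
import torch

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

    # ... run the model or generation loop you want to profile here ...

    max_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"Max memory allocated: {max_gb:.2f} GB")
else:
    print("No CUDA device found; skipping peak-memory measurement.")
```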