Commit 04b3d92

[llama-mm] Add torch.cond to replace if condition in MHA (#6869)
* [llama-mm] Add torch.cond to replace if condition in MHA

  Summary: In torchtune's MultiHeadAttention we have this logic: if `y` is not None, compute `k` and `v` from `y` and update the KVCache; otherwise (if `y` is None), read `k` and `v` back from the KVCache. This data-dependent branch cannot be handled by torch.export. Here I'm proposing a rewrite: if `y` is not entirely NaN (not a number), compute `k` and `v` from `y` and update the KVCache; otherwise (if every value of `y` is NaN), read `k` and `v` back from the KVCache. This rewrite lets the module satisfy the requirements of `torch.cond` and avoids specialization:

  * The operands to `torch.cond` must have the same shape for the true branch and the false branch.

  This means we will have to change this logic in torchtune:

  ```
  if encoder_input is not None:
      encoder_embed = self.encoder(**encoder_input)

  output = self.decoder(
      tokens=tokens,
      mask=mask,
      encoder_input=encoder_embed,
      encoder_mask=encoder_mask,
      input_pos=input_pos,
  )
  ```

  To be:

  ```
  if encoder_input is not None:
      encoder_embed = self.encoder(**encoder_input)
  else:
      encoder_embed = torch.full_like(encoder_input, torch.nan)

  output = self.decoder(
      tokens=tokens,
      mask=mask,
      encoder_input=encoder_embed,
      encoder_mask=encoder_mask,
      input_pos=input_pos,
  )
  ```

  Test Plan: Rely on unit tests.

* Add test
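A minimal, self-contained sketch of the pattern described in the commit message (the toy module, buffer, and shapes below are invented for illustration; this is not the ExecuTorch attention class): both branches of `torch.cond` take the same operands and return same-shaped tensors, and the caller passes a NaN-filled tensor where it previously passed `None`.

```python
import torch


class ToyCondKV(torch.nn.Module):
    """Toy stand-in for the MHA rewrite: branch on "is y all NaN?" with torch.cond."""

    def __init__(self, dim: int = 4, seq_len: int = 8):
        super().__init__()
        # Pretend this buffer is a previously filled k-cache.
        self.register_buffer("k_cache", torch.randn(1, seq_len, dim))
        self.k_proj = torch.nn.Linear(dim, dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        def read_cache(y):
            # "y is all NaN" (the old `y is None` case): read k from the cache.
            return self.k_cache.clone()

        def compute_k(y):
            # "y has real values": compute k from y.
            return self.k_proj(y)

        # Both branches accept the same operands and return same-shaped tensors,
        # so export can keep both in the graph instead of specializing.
        return torch.cond(torch.isnan(y).all(), read_cache, compute_k, (y,))


m = ToyCondKV()
y_real = torch.randn(1, 8, 4)
y_nan = torch.full_like(y_real, torch.nan)  # caller passes NaNs instead of None
print(m(y_real).shape, m(y_nan).shape)  # both torch.Size([1, 8, 4])
```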
1 parent 7b76f0f commit 04b3d92

3 files changed: 93 additions (+), 15 deletions (−)

extension/llm/modules/attention.py

Lines changed: 30 additions & 13 deletions
```diff
@@ -246,7 +246,6 @@ def forward(
         # x has shape [b, s_x, d]
         # y has shape [b, s_y, d]
         b, s_x, _ = x.shape
-        s_y = y.shape[1] if y is not None else 0
 
         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
@@ -263,16 +262,9 @@ def forward(
         if self.q_norm is not None:
             q = self.q_norm(q)
 
-        if y is None:
-            if self.kv_cache is None:
-                raise ValueError(
-                    "Must provide y input or use kv_cache to enable streaming decoding"
-                )
-            k = self.kv_cache.k_cache
-            v = self.kv_cache.v_cache
-        else:
+        def calculate_kv(y):
             # Update k and v shape, positional embeddings, and normalization
-
+            s_y = y.shape[1]
             # k has shape [b, s_y, num_kv_heads * head_dim]
             # v has shape [b, s_y, num_kv_heads * head_dim]
             k = self.k_proj(y)
@@ -288,12 +280,37 @@ def forward(
             # Normalize k
             if self.k_norm is not None:
                 k = self.k_norm(k)
+            return k, v
+
+        def true_fn(y):
+            kv_cache = self.kv_cache.clone()
+            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
+
+        def false_fn(y):
+            k, v = calculate_kv(y)
+            kv_cache = self.kv_cache.clone()
+            kv_cache.update(k, v)
+            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
 
+        # If kv cache is None, we expect y to be provided
+        if self.kv_cache is None:
+            assert (
+                y is not None
+            ), "Must provide y input or use kv_cache to enable streaming decoding"
+            k, v = calculate_kv(y)
+        else:
+            # Expecting the k, v returning here to be the same size of self.kv_cache
+            # In eager, we expect this predicate to specialize. In export, this will
+            # become a SymBool so it's not specialized.
+            k, v, cache_pos = torch.cond(
+                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
+            )
         # Update key-value cache
-        if self.kv_cache is not None and self.cache_enabled:
-            k, v = self.kv_cache.update(k, v)
+        self.kv_cache.k_cache.copy_(k)
+        self.kv_cache.v_cache.copy_(v)
+        self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask)
         return self.output_proj(output)
 
```
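As the inline comments note, the predicate specializes in eager mode but becomes a SymBool under export, so both branches are preserved in the exported graph. A rough sketch of that behavior with a hypothetical toy module (not this attention class):

```python
import torch


class CondToy(torch.nn.Module):
    def forward(self, y: torch.Tensor) -> torch.Tensor:
        def true_fn(y):
            return torch.zeros_like(y)

        def false_fn(y):
            return y + 1.0

        # Data-dependent boolean predicate: export keeps both branches
        # instead of specializing on the value of y.
        return torch.cond(torch.isnan(y).all(), true_fn, false_fn, (y,))


ep = torch.export.export(CondToy(), (torch.randn(2, 3),))
# The exported graph contains a single torch.ops.higher_order.cond call whose
# true/false subgraphs correspond to the two branches above.
print(ep.graph_module.code)
```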

extension/llm/modules/kv_cache.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -127,3 +127,22 @@ def update(
         self.cache_pos.add_(seq_len)
 
         return k_out, v_out
+
+    def clone(self) -> "KVCache":
+        """Create a clone of the KVCache."""
+        if self.transpose_cache:
+            num_kv_heads = self.k_cache.shape[1]
+        else:
+            num_kv_heads = self.k_cache.shape[2]
+        clone = KVCache(
+            batch_size=self.batch_size,
+            max_seq_len=self.max_seq_len,
+            num_kv_heads=num_kv_heads,
+            head_dim=self.k_cache.shape[3],
+            dtype=self.k_cache.dtype,
+            transpose_cache=self.transpose_cache,
+        )
+        clone.k_cache.copy_(self.k_cache)
+        clone.v_cache.copy_(self.v_cache)
+        clone.cache_pos.copy_(self.cache_pos)
+        return clone
```
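`clone()` exists so that each `torch.cond` branch can work on a functional copy of the cache: the branches are expected to be free of in-place mutation of module state, and the chosen branch's result is committed back with `copy_()` after the cond (see the `copy_` calls in attention.py above). A rough sketch of that pattern with a toy buffer (not the real KVCache constructor or layout):

```python
import torch


class ToyCache(torch.nn.Module):
    def __init__(self, seq_len: int = 4, dim: int = 2):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(1, seq_len, dim))

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        def read_fn(y):
            # Cache read: return a functional copy, no mutation in the branch.
            return self.k_cache.clone()

        def write_fn(y):
            # Cache write: build the new contents out-of-place from a clone.
            return self.k_cache.clone() + y

        k = torch.cond(torch.isnan(y).all(), read_fn, write_fn, (y,))
        # Commit the chosen branch's result to the persistent buffer
        # outside of torch.cond, mirroring the copy_() calls in attention.py.
        self.k_cache.copy_(k)
        return k
```

In the real module, `false_fn` achieves the same thing by calling `update(k, v)` on a `clone()` of the cache rather than on `self.kv_cache` itself.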

extension/llm/modules/test/test_attention.py

Lines changed: 44 additions & 2 deletions
```diff
@@ -27,7 +27,7 @@ def setUp(self):
         torch.manual_seed(0)
         # Constants
         self.embed_dim = 2048
-        self.num_heads = 32
+        self.num_heads = 8
         self.num_kv_heads = 8
         self.head_dim = 64
         self.max_seq_len = 128
@@ -41,10 +41,14 @@ def setUp(self):
         self.k_proj = torch.nn.Linear(
             self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
         )
+        self.k_proj.weight.requires_grad = False
         self.v_proj = torch.nn.Linear(
             self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
         )
-        self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.v_proj.weight.requires_grad = False
+        self.output_proj = torch.nn.Linear(
+            self.num_heads * self.head_dim, self.embed_dim, bias=False
+        )
         self.pos_embeddings = Llama3ScaledRoPE(
             dim=self.head_dim,
             max_seq_len=self.max_seq_len,
@@ -90,6 +94,12 @@ def setUp(self):
             {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
             {0: torch.export.Dim.STATIC, 1: seq_len_dim},
         )
+        self.causal_mask = torch.tril(
+            torch.ones(
+                size=(self.max_seq_len, self.max_seq_len),
+                dtype=torch.bool,
+            )
+        )
 
     def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x) # Self attention.
@@ -195,3 +205,35 @@ def test_attention_executorch(self):
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
         assert_close(et_res[0], tt_res)
+
+    def test_attention_torch_cond_eager(self):
+        # Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond.
+        # We need to make sure they are giving the same results regarding the if condition.
+        # For the first run of MHA we provide `y` (self.x) but for the second run it will be a tensor full of nan.
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+
+        # mask
+        mask = self.causal_mask[self.input_pos, :]
+        # First run
+        et_res = self.et_mha(
+            self.x, self.x, mask=mask, input_pos=self.input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, self.x, mask=mask, input_pos=self.input_pos
+        )  # Self attention with input pos.
+
+        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        # Second run test kv cache read. Input pos is [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+
+        empty_y = torch.full_like(self.x, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = self.et_mha(
+            self.x, empty_y, mask=mask, input_pos=next_input_pos
+        )  # Self attention with input pos.
+        tt_res = self.tt_mha(
+            self.x, None, mask=mask, input_pos=next_input_pos
+        )  # Self attention with input pos.
+
+        assert_close(et_res, tt_res)
```
