
Commit 4c138ef

[llama-mm] Add torch.cond to replace if condition in MHA
Summary: In torchtune's MultiHeadAttention we have this logic: if `y` is not None, calculate `k` and `v` from `y` and update the KVCache; otherwise (if `y` is None), retrieve `k` and `v` from the KVCache. This data-dependent Python branch cannot be captured by torch.export.

Here I'm proposing a rewrite: if `y` is not entirely NaN (not a number), calculate `k` and `v` from `y` and update the KVCache; otherwise (if every value of `y` is NaN), retrieve `k` and `v` from the KVCache. This rewrite lets the module satisfy the requirements of `torch.cond` and avoid specialization:

* The operands to `torch.cond` must have the same shape for the true branch and the false branch.

This means we will have to change this logic in torchtune:

```
if encoder_input is not None:
    encoder_embed = self.encoder(**encoder_input)

output = self.decoder(
    tokens=tokens,
    mask=mask,
    encoder_input=encoder_embed,
    encoder_mask=encoder_mask,
    input_pos=input_pos,
)
```

to be:

```
if encoder_input is not None:
    encoder_embed = self.encoder(**encoder_input)
else:
    encoder_embed = torch.full_like(encoder_input, torch.nan)

output = self.decoder(
    tokens=tokens,
    mask=mask,
    encoder_input=encoder_embed,
    encoder_mask=encoder_mask,
    input_pos=input_pos,
)
```

Test Plan: Rely on unit tests.

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 71612a6 commit 4c138ef
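As background for the summary above, here is a minimal, self-contained sketch of the `torch.cond` contract it relies on; the branch functions and toy shapes are illustrative only and are not part of this commit. Both branches take the same operands and must return tensors with matching shape and dtype, which is why an all-NaN placeholder shaped like the real input can stand in for "no input".

```
import torch


def read_branch(y: torch.Tensor) -> torch.Tensor:
    # "y is all NaN": pretend to read previously cached values.
    return torch.zeros_like(y)


def compute_branch(y: torch.Tensor) -> torch.Tensor:
    # "y is real data": compute fresh values with the same shape.
    return y * 2.0


y_real = torch.randn(2, 3)
y_placeholder = torch.full((2, 3), float("nan"))  # same shape as the real input

# The predicate mirrors the commit: an all-NaN input means "use the cache".
out_fresh = torch.cond(
    torch.isnan(y_real).all().item(), read_branch, compute_branch, (y_real,)
)
out_cached = torch.cond(
    torch.isnan(y_placeholder).all().item(), read_branch, compute_branch, (y_placeholder,)
)
print(out_fresh.shape, out_cached.shape)  # both torch.Size([2, 3])
```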

File tree: 3 files changed, +51 -12 lines changed


extension/llm/modules/attention.py

Lines changed: 29 additions & 12 deletions
```
@@ -246,7 +246,6 @@ def forward(
         # x has shape [b, s_x, d]
         # y has shape [b, s_y, d]
         b, s_x, _ = x.shape
-        s_y = y.shape[1] if y is not None else 0
 
         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
@@ -263,16 +262,9 @@ def forward(
         if self.q_norm is not None:
             q = self.q_norm(q)
 
-        if y is None:
-            if self.kv_cache is None:
-                raise ValueError(
-                    "Must provide y input or use kv_cache to enable streaming decoding"
-                )
-            k = self.kv_cache.k_cache
-            v = self.kv_cache.v_cache
-        else:
+        def calculate_kv(y):
             # Update k and v shape, positional embeddings, and normalization
-
+            s_y = y.shape[1]
             # k has shape [b, s_y, num_kv_heads * head_dim]
             # v has shape [b, s_y, num_kv_heads * head_dim]
             k = self.k_proj(y)
@@ -288,10 +280,35 @@ def forward(
             # Normalize k
             if self.k_norm is not None:
                 k = self.k_norm(k)
+            return k, v
+
+        def true_fn(y):
+            kv_cache = self.kv_cache.clone()
+            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
+
+        def false_fn(y):
+            k, v = calculate_kv(y)
+            kv_cache = self.kv_cache.clone()
+            kv_cache.update(k, v)
+            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
 
+        # If kv cache is None, we expect y to be provided
+        if self.kv_cache is None:
+            assert (
+                y is not None
+            ), "Must provide y input or use kv_cache to enable streaming decoding"
+            k, v = calculate_kv(y)
+        else:
+            # Expecting the k, v returning here to be the same size of self.kv_cache
+            # In eager, we expect this predicate to specialize. In export, this will
+            # become a SymBool so it's not specialized.
+            k, v, cache_pos = torch.cond(
+                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
+            )
         # Update key-value cache
-        if self.kv_cache is not None and self.cache_enabled:
-            k, v = self.kv_cache.update(k, v)
+            self.kv_cache.k_cache.copy_(k)
+            self.kv_cache.v_cache.copy_(v)
+            self.kv_cache.cache_pos.copy_(cache_pos)
 
         output = self._sdpa(q, k, v, b, s_x)
         return self.output_proj(output)
```
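The structure of the new forward — branch functions that return cloned cache tensors, with the `copy_` back into the module's buffers happening only after the `torch.cond` — follows from `torch.cond` requiring its branches to be free of side effects. Below is a hedged, self-contained sketch of the same pattern; the `ToyCache` class, its buffer shapes, and the trivial "update" are assumptions for illustration, not code from this diff.

```
import torch


class ToyCache(torch.nn.Module):
    """Illustrative stand-in for the KV-cache update pattern above."""

    def __init__(self, seq_len: int = 8, dim: int = 4):
        super().__init__()
        self.register_buffer("k_cache", torch.zeros(seq_len, dim))
        self.register_buffer("v_cache", torch.zeros(seq_len, dim))

    def forward(self, y: torch.Tensor):
        def true_fn(y):
            # All-NaN input: return copies of the current cache contents.
            return self.k_cache.clone(), self.v_cache.clone()

        def false_fn(y):
            # Real input: the "updated" cache is derived from y (trivially here).
            return y.clone(), 2.0 * y

        # Branches are functional; both return two tensors of identical shape.
        k, v = torch.cond(torch.isnan(y).all().item(), true_fn, false_fn, (y,))
        # Mutate the module's own buffers only after the cond, as in the diff above.
        self.k_cache.copy_(k)
        self.v_cache.copy_(v)
        return k, v


cache = ToyCache()
k, v = cache(torch.randn(8, 4))                    # write path (real input)
k2, v2 = cache(torch.full((8, 4), float("nan")))   # read path (NaN placeholder)
assert torch.equal(k, k2) and torch.equal(v, v2)
```

Cloning inside the branches keeps them functional; the in-place update of the live cache is hoisted outside the cond so export can capture both branches without specializing on the predicate.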

extension/llm/modules/kv_cache.py

Lines changed: 20 additions & 0 deletions
```
@@ -127,3 +127,23 @@ def update(
         self.cache_pos.add_(seq_len)
 
         return k_out, v_out
+
+    def clone(self) -> "KVCache":
+        """Create a clone of the KVCache."""
+        if self.transpose_cache:
+            max_seq_len = self.k_cache.shape[1]
+            num_kv_heads = self.k_cache.shape[2]
+        else:
+            max_seq_len = self.k_cache.shape[2]
+            num_kv_heads = self.k_cache.shape[1]
+        clone = KVCache(
+            batch_size=self.batch_size,
+            max_seq_len=max_seq_len,
+            num_kv_heads=num_kv_heads,
+            head_dim=self.k_cache.shape[3],
+            dtype=self.k_cache.dtype,
+        )
+        clone.k_cache.copy_(self.k_cache)
+        clone.v_cache.copy_(self.v_cache)
+        clone.cache_pos.copy_(self.cache_pos)
+        return clone
```
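For orientation, a brief usage sketch of the new `clone()`: the constructor arguments mirror those used in the diff above, but the exact `KVCache` signature and import path are assumptions, not taken from this commit.

```
import torch

# Assumed import path for the module shown above.
from executorch.extension.llm.modules.kv_cache import KVCache

cache = KVCache(
    batch_size=1, max_seq_len=16, num_kv_heads=2, head_dim=8, dtype=torch.float32
)
snapshot = cache.clone()

# The clone owns its own storage, so a torch.cond branch can return its
# tensors without aliasing or mutating the live cache buffers.
assert snapshot.k_cache.data_ptr() != cache.k_cache.data_ptr()
assert torch.equal(snapshot.k_cache, cache.k_cache)
```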

extension/llm/modules/test/test_attention.py

Lines changed: 2 additions & 0 deletions
```
@@ -41,9 +41,11 @@ def setUp(self):
         self.k_proj = torch.nn.Linear(
             self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
         )
+        self.k_proj.weight.requires_grad = False
         self.v_proj = torch.nn.Linear(
             self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
         )
+        self.v_proj.weight.requires_grad = False
         self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.pos_embeddings = Llama3ScaledRoPE(
             dim=self.head_dim,
```
