65 changes: 28 additions & 37 deletions extension/llm/modules/attention.py
@@ -246,6 +246,7 @@ def forward(
# x has shape [b, s_x, d]
# y has shape [b, s_y, d]
b, s_x, _ = x.shape
s_y = y.shape[1] if y is not None else 0

# q has shape [b, s_x, num_heads * head_dim]
q = self.q_proj(x)
@@ -262,9 +263,16 @@ def forward(
if self.q_norm is not None:
q = self.q_norm(q)

def calculate_kv(y):
if y is None:
if self.kv_cache is None:
raise ValueError(
"Must provide y input or use kv_cache to enable streaming decoding"
)
k = self.kv_cache.k_cache
v = self.kv_cache.v_cache
else:
# Update k and v shape, positional embeddings, and normalization
s_y = y.shape[1]

# k has shape [b, s_y, num_kv_heads * head_dim]
# v has shape [b, s_y, num_kv_heads * head_dim]
k = self.k_proj(y)
@@ -280,37 +288,12 @@ def calculate_kv(y):
# Normalize k
if self.k_norm is not None:
k = self.k_norm(k)
return k, v

def true_fn(y):
kv_cache = self.kv_cache.clone()
return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

def false_fn(y):
k, v = calculate_kv(y)
kv_cache = self.kv_cache.clone()
kv_cache.update(k, v)
return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

# If kv cache is None, we expect y to be provided
if self.kv_cache is None:
assert (
y is not None
), "Must provide y input or use kv_cache to enable streaming decoding"
k, v = calculate_kv(y)
else:
# Expecting the k, v returned here to be the same size as self.kv_cache
# In eager, we expect this predicate to specialize. In export, this will
# become a SymBool so it's not specialized.
k, v, cache_pos = torch.cond(
torch.isnan(y).all().item(), true_fn, false_fn, (y,)
)
# Update key-value cache
self.kv_cache.k_cache.copy_(k)
self.kv_cache.v_cache.copy_(v)
self.kv_cache.cache_pos.copy_(cache_pos)
if self.kv_cache is not None and self.cache_enabled:
k, v = self.kv_cache.update(k, v)

output = self._sdpa(q, k, v, b, s_x, mask=mask)
output = self._sdpa(q, k, v, b, s_x)
Contributor Author:
Note to self: the refactor removed the mask arg, make sure to add it back

return self.output_proj(output)
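For orientation (not part of the diff): with the torch.cond / NaN-sentinel branching gone, the k/v handling reduces to a plain Python branch on y. A minimal sketch of that control flow, assuming (as in torchtune's MultiHeadAttention) that the cache update only runs when a new y is provided; the helper name _compute_kv is illustrative, not from the PR:

def _compute_kv(self, y):
    if y is None:
        # Streaming decode: no new y, so reuse what the cache already holds.
        if self.kv_cache is None:
            raise ValueError(
                "Must provide y input or use kv_cache to enable streaming decoding"
            )
        return self.kv_cache.k_cache, self.kv_cache.v_cache
    # New y: project it, then append to the cache when caching is enabled.
    k = self.k_proj(y)
    v = self.v_proj(y)
    if self.kv_cache is not None and self.cache_enabled:
        k, v = self.kv_cache.update(k, v)
    return k, v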


@@ -352,17 +335,25 @@ def forward(
# View + expand + reshape bring num_kv_heads to num_heads for k and v
# to match q.

# [bsz, n_h, s, h_d]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# k: [bsz, seq_len, n_kv, 1, h_d]
# v: [bsz, seq_len, n_kv, 1, h_d]
k = k.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
v = v.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)

# Expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
if self.num_heads != self.num_kv_heads:
expand_shape = (-1, -1, self.q_per_kv, -1, -1)
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
k = k.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
v = v.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)

# [bsz, s, n_h, h_d]
k = k.reshape(bsz, -1, self.num_heads, self.head_dim)
v = v.reshape(bsz, -1, self.num_heads, self.head_dim)

# [bsz, n_h, s, h_d]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

output = self._attention_fn(
q,
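A worked shape trace of the view + expand + reshape path above, using made-up sizes (bsz=2, seq_len=16, num_heads=8, num_kv_heads=2, head_dim=64, so q_per_kv=4); the expand only adjusts strides, and the actual copy happens at reshape:

import torch

bsz, seq_len, n_h, n_kv, h_d = 2, 16, 8, 2, 64
q_per_kv = n_h // n_kv  # 4 query heads share each kv head

k = torch.randn(bsz, seq_len, n_kv * h_d)   # projected keys: [2, 16, 128]
k = k.view(bsz, -1, n_kv, 1, h_d)           # [2, 16, 2, 1, 64]
k = k.expand(bsz, -1, n_kv, q_per_kv, h_d)  # [2, 16, 2, 4, 64], no data copied yet
k = k.reshape(bsz, -1, n_h, h_d)            # [2, 16, 8, 64], copy happens here
k = k.transpose(1, 2)                       # [2, 8, 16, 64], matches q's layout
print(k.shape)                              # torch.Size([2, 8, 16, 64])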
13 changes: 8 additions & 5 deletions extension/llm/modules/kv_cache.py
@@ -111,11 +111,13 @@ def update(
v_out = self.v_cache

if self.transpose_cache:
k_out[:, :, self.cache_pos[:seq_len]] = k_val
v_out[:, :, self.cache_pos[:seq_len]] = v_val
pos_mask = torch.arange(k_out.shape[2]) < seq_len
k_out[:, :, self.cache_pos[pos_mask]] = k_val
v_out[:, :, self.cache_pos[pos_mask]] = v_val
else:
k_out[:, self.cache_pos[:seq_len]] = k_val
v_out[:, self.cache_pos[:seq_len]] = v_val
pos_mask = torch.arange(k_out.shape[1]) < seq_len
k_out[:, self.cache_pos[pos_mask], :] = k_val
v_out[:, self.cache_pos[pos_mask], :] = v_val

# move cache_pos forward by seq_len positions
# cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
@@ -124,7 +126,8 @@
# this allows us to track the current position in the cache
# after the last update in a compile-friendly way without any dynamism
# e.g. relying on an int size tracker, or re-creating cache_pos every time
self.cache_pos.add_(seq_len)
mask = (seq_len > 0) * 1
self.cache_pos.add_(seq_len * mask)

return k_out, v_out

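A standalone illustration of the masked cache write above, with toy sizes and the transpose_cache=False layout assumed: pos_mask selects the first seq_len entries of the fixed-size cache_pos buffer instead of a data-dependent slice, and the final add_ advances the positions only when something was written:

import torch

max_seq_len, seq_len = 10, 2
cache_pos = torch.arange(max_seq_len)              # fresh cache: next write lands at slot 0
k_out = torch.zeros(1, max_seq_len, 4)             # [bsz, max_seq_len, dim]
k_val = torch.ones(1, seq_len, 4)                  # new keys for 2 positions

pos_mask = torch.arange(k_out.shape[1]) < seq_len  # True only for the first seq_len entries
k_out[:, cache_pos[pos_mask], :] = k_val           # writes cache slots 0 and 1

mask = (seq_len > 0) * 1                           # 0 when nothing was written
cache_pos.add_(seq_len * mask)                     # positions advance by seq_len
print(cache_pos[:3])                               # tensor([2, 3, 4])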
15 changes: 8 additions & 7 deletions extension/llm/modules/test/test_attention.py
@@ -219,28 +219,29 @@ def test_attention_torch_cond_eager(self):
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

# mask
mask = self.causal_mask[self.input_pos, :]
# First run
# First run. For the second argument to forward, it doesn't matter whether it is
# the same as the first (self-attention) or a different tensor (cross-attention).
et_res = self.et_mha(
self.x, self.x, mask=mask, input_pos=self.input_pos
) # Self attention with input pos.
)
tt_res = self.tt_mha(
self.x, self.x, mask=mask, input_pos=self.input_pos
) # Self attention with input pos.
)

self.assertTrue(torch.allclose(et_res, tt_res))

# Second run test kv cache read. Input pos is [10, 11, ..., 19]
# Second run tests the kv cache read. Input pos is [10, 11, ..., 19].
next_input_pos = torch.arange(10, 20).unsqueeze(0)

empty_y = torch.full_like(self.x, torch.nan)
mask = self.causal_mask[next_input_pos, :]
et_res = self.et_mha(
self.x, empty_y, mask=mask, input_pos=next_input_pos
) # Self attention with input pos.
) # Cross attention with no y input; ET passes a NaN-filled placeholder tensor.
tt_res = self.tt_mha(
self.x, None, mask=mask, input_pos=next_input_pos
) # Self attention with input pos.
) # Cross attention with no y input; TorchTune passes None.

assert_close(et_res, tt_res)
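For reference, a small sketch of how the mask slicing in both runs behaves (toy max_seq_len, not the test fixture's values): indexing the full causal mask with the current positions yields one row of allowed key positions per query token.

import torch

max_seq_len = 20
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
next_input_pos = torch.arange(10, 20).unsqueeze(0)  # [1, 10]: decoding steps 10..19
mask = causal_mask[next_input_pos, :]               # [1, 10, 20]; row i permits keys 0..(10 + i)
print(mask.shape)                                   # torch.Size([1, 10, 20])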