43 changes: 30 additions & 13 deletions extension/llm/modules/attention.py
@@ -246,7 +246,6 @@ def forward(
        # x has shape [b, s_x, d]
        # y has shape [b, s_y, d]
        b, s_x, _ = x.shape
        s_y = y.shape[1] if y is not None else 0

        # q has shape [b, s_x, num_heads * head_dim]
        q = self.q_proj(x)
@@ -263,16 +262,9 @@ def forward(
        if self.q_norm is not None:
            q = self.q_norm(q)

        if y is None:
            if self.kv_cache is None:
                raise ValueError(
                    "Must provide y input or use kv_cache to enable streaming decoding"
                )
            k = self.kv_cache.k_cache
            v = self.kv_cache.v_cache
        else:
        def calculate_kv(y):
            # Update k and v shape, positional embeddings, and normalization

            s_y = y.shape[1]
            # k has shape [b, s_y, num_kv_heads * head_dim]
            # v has shape [b, s_y, num_kv_heads * head_dim]
            k = self.k_proj(y)
@@ -288,12 +280,37 @@ def forward(
            # Normalize k
            if self.k_norm is not None:
                k = self.k_norm(k)
            return k, v

        def true_fn(y):
            kv_cache = self.kv_cache.clone()
            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

        def false_fn(y):
            k, v = calculate_kv(y)
            kv_cache = self.kv_cache.clone()
            kv_cache.update(k, v)
            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

        # If kv cache is None, we expect y to be provided
        if self.kv_cache is None:
            assert (
                y is not None
            ), "Must provide y input or use kv_cache to enable streaming decoding"
            k, v = calculate_kv(y)
        else:
            # Expecting the k, v returned here to be the same size as self.kv_cache.
            # In eager, we expect this predicate to specialize. In export, this will
            # become a SymBool so it's not specialized.
            k, v, cache_pos = torch.cond(
                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
            )
        # Update key-value cache
        if self.kv_cache is not None and self.cache_enabled:
            k, v = self.kv_cache.update(k, v)
            self.kv_cache.k_cache.copy_(k)
            self.kv_cache.v_cache.copy_(v)
            self.kv_cache.cache_pos.copy_(cache_pos)

        output = self._sdpa(q, k, v, b, s_x)
        output = self._sdpa(q, k, v, b, s_x, mask=mask)
        return self.output_proj(output)
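For readers unfamiliar with torch.cond, here is a minimal, self-contained sketch of the pattern the rewritten forward() relies on (the function name, tensors, and branch bodies are made up for illustration): an all-NaN y acts as the sentinel that selects the cache-read branch, while a regular y selects the compute branch.

import torch


def cond_sketch(y: torch.Tensor, cached: torch.Tensor) -> torch.Tensor:
    # Both branches must accept the same operands and return tensors with matching metadata.
    def read_cache(y, cached):
        return cached.clone()

    def compute_new(y, cached):
        return y * 2.0

    # In eager mode the predicate evaluates to a plain bool and one branch runs directly;
    # under torch.export, .item() yields a SymBool so neither branch is specialized away.
    return torch.cond(torch.isnan(y).all().item(), read_cache, compute_new, (y, cached))


# An all-NaN input picks the cache-read branch; a regular input picks the compute branch.
print(cond_sketch(torch.full((2, 3), float("nan")), torch.ones(2, 3)))
print(cond_sketch(torch.randn(2, 3), torch.ones(2, 3)))

In the diff above the same idea is applied to a cloned KVCache inside true_fn/false_fn, and the module's real cache buffers are only written back with copy_ after torch.cond returns.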


19 changes: 19 additions & 0 deletions extension/llm/modules/kv_cache.py
@@ -127,3 +127,22 @@ def update(
        self.cache_pos.add_(seq_len)

        return k_out, v_out

def clone(self) -> "KVCache":
"""Create a clone of the KVCache."""
if self.transpose_cache:
num_kv_heads = self.k_cache.shape[1]
else:
num_kv_heads = self.k_cache.shape[2]
clone = KVCache(
batch_size=self.batch_size,
max_seq_len=self.max_seq_len,
num_kv_heads=num_kv_heads,
head_dim=self.k_cache.shape[3],
dtype=self.k_cache.dtype,
transpose_cache=self.transpose_cache,
)
clone.k_cache.copy_(self.k_cache)
clone.v_cache.copy_(self.v_cache)
clone.cache_pos.copy_(self.cache_pos)
return clone
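A rough usage sketch of the new clone() helper, assuming the constructor arguments shown in this diff and an importable module path (the path, shapes, and dtype below are illustrative): cloning produces an independent cache whose buffers can be modified, e.g. via update() inside a torch.cond branch, without touching the original.

import torch

# Assumed import path; adjust to wherever this KVCache is importable from in your build.
from executorch.extension.llm.modules.kv_cache import KVCache

cache = KVCache(
    batch_size=1,
    max_seq_len=128,
    num_kv_heads=8,
    head_dim=64,
    dtype=torch.float32,
    transpose_cache=True,
)
k_before = cache.k_cache.clone()

# false_fn in attention.py pairs this clone with kv_cache.update(k, v); here we just poke
# the buffers directly to show that the clone is fully independent of the original.
scratch = cache.clone()
scratch.k_cache.add_(1.0)
scratch.cache_pos.add_(1)

assert torch.equal(cache.k_cache, k_before)  # original buffers untouched
assert not torch.equal(scratch.k_cache, cache.k_cache)

This is the property the torch.cond branches depend on: the scratch cache absorbs the update, and the real cache is only mutated afterwards via the explicit copy_ calls in forward().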
46 changes: 44 additions & 2 deletions extension/llm/modules/test/test_attention.py
@@ -27,7 +27,7 @@ def setUp(self):
        torch.manual_seed(0)
        # Constants
        self.embed_dim = 2048
        self.num_heads = 32
        self.num_heads = 8
        self.num_kv_heads = 8
        self.head_dim = 64
        self.max_seq_len = 128
@@ -41,10 +41,14 @@ def setUp(self):
        self.k_proj = torch.nn.Linear(
            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
        )
        self.k_proj.weight.requires_grad = False
        self.v_proj = torch.nn.Linear(
            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
        )
        self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj.weight.requires_grad = False
        self.output_proj = torch.nn.Linear(
            self.num_heads * self.head_dim, self.embed_dim, bias=False
        )
        self.pos_embeddings = Llama3ScaledRoPE(
            dim=self.head_dim,
            max_seq_len=self.max_seq_len,
@@ -90,6 +94,12 @@ def setUp(self):
            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
        )
        self.causal_mask = torch.tril(
            torch.ones(
                size=(self.max_seq_len, self.max_seq_len),
                dtype=torch.bool,
            )
        )

    def test_attention_eager(self):
        et_res = self.et_mha(self.x, self.x)  # Self attention.
@@ -195,3 +205,35 @@ def test_attention_executorch(self):
        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

        assert_close(et_res[0], tt_res)

    def test_attention_torch_cond_eager(self):
        # Unlike vanilla torchtune MHA, we rewrite the if condition with torch.cond, so we
        # need to make sure both branches give the same results as the original if/else.
        # For the first run of MHA we provide `y` (self.x); for the second run y is a
        # tensor full of NaNs, signaling that only the kv cache should be read.
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

        # mask
        mask = self.causal_mask[self.input_pos, :]
        # First run
        et_res = self.et_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.

        self.assertTrue(torch.allclose(et_res, tt_res))

        # Second run tests the kv cache read. Input pos is [10, 11, ..., 19].
        next_input_pos = torch.arange(10, 20).unsqueeze(0)

        empty_y = torch.full_like(self.x, torch.nan)
        mask = self.causal_mask[next_input_pos, :]
        et_res = self.et_mha(
            self.x, empty_y, mask=mask, input_pos=next_input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, None, mask=mask, input_pos=next_input_pos
        )  # Self attention with input pos.

        assert_close(et_res, tt_res)