From 220aec8514b1773334fb6fc886a3429a8253ae6d Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Mon, 6 Jan 2025 20:27:46 -0800
Subject: [PATCH] Use mask instead of cond for attention conditional logic

---
 extension/llm/modules/attention.py           | 65 +++++++++-----------
 extension/llm/modules/kv_cache.py            | 13 ++--
 extension/llm/modules/test/test_attention.py | 15 ++---
 3 files changed, 44 insertions(+), 49 deletions(-)

diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index 60183801b42..74e14076b37 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -246,6 +246,7 @@ def forward(
         # x has shape [b, s_x, d]
         # y has shape [b, s_y, d]
         b, s_x, _ = x.shape
+        s_y = y.shape[1] if y is not None else 0
 
         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
@@ -262,9 +263,16 @@ def forward(
         if self.q_norm is not None:
             q = self.q_norm(q)
 
-        def calculate_kv(y):
+        if y is None:
+            if self.kv_cache is None:
+                raise ValueError(
+                    "Must provide y input or use kv_cache to enable streaming decoding"
+                )
+            k = self.kv_cache.k_cache
+            v = self.kv_cache.v_cache
+        else:
             # Update k and v shape, positional embeddings, and normalization
-            s_y = y.shape[1]
+
             # k has shape [b, s_y, num_kv_heads * head_dim]
             # v has shape [b, s_y, num_kv_heads * head_dim]
             k = self.k_proj(y)
@@ -280,37 +288,12 @@ def calculate_kv(y):
             # Normalize k
             if self.k_norm is not None:
                 k = self.k_norm(k)
-            return k, v
-
-        def true_fn(y):
-            kv_cache = self.kv_cache.clone()
-            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
 
-        def false_fn(y):
-            k, v = calculate_kv(y)
-            kv_cache = self.kv_cache.clone()
-            kv_cache.update(k, v)
-            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
-
-        # If kv cache is None, we expect y to be provided
-        if self.kv_cache is None:
-            assert (
-                y is not None
-            ), "Must provide y input or use kv_cache to enable streaming decoding"
-            k, v = calculate_kv(y)
-        else:
-            # Expecting the k, v returning here to be the same size of self.kv_cache
-            # In eager, we expect this predicate to specialize. In export, this will
-            # become a SymBool so it's not specialized.
-            k, v, cache_pos = torch.cond(
-                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
-            )
             # Update key-value cache
-            self.kv_cache.k_cache.copy_(k)
-            self.kv_cache.v_cache.copy_(v)
-            self.kv_cache.cache_pos.copy_(cache_pos)
+        if self.kv_cache is not None and self.cache_enabled:
+            k, v = self.kv_cache.update(k, v)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x)
         return self.output_proj(output)
 
@@ -352,17 +335,25 @@ def forward(
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.
-        # [bsz, n_h, s, h_d]
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
+        # k: [bsz, seq_len, n_kv, 1, h_d]
+        # v: [bsz, seq_len, n_kv, 1, h_d]
+        k = k.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
+        v = v.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
 
         # Expand the key and value tensors to have the same shape
         # as the query tensor by copying values across the relevant dim
         if self.num_heads != self.num_kv_heads:
-            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
-            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+            k = k.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
+            v = v.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
+
+        # [bsz, s, n_h, h_d]
+        k = k.reshape(bsz, -1, self.num_heads, self.head_dim)
+        v = v.reshape(bsz, -1, self.num_heads, self.head_dim)
+
+        # [bsz, n_h, s, h_d]
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
 
         output = self._attention_fn(
             q,
diff --git a/extension/llm/modules/kv_cache.py b/extension/llm/modules/kv_cache.py
index db940bca3f8..1921be4194b 100644
--- a/extension/llm/modules/kv_cache.py
+++ b/extension/llm/modules/kv_cache.py
@@ -111,11 +111,13 @@ def update(
         v_out = self.v_cache
 
         if self.transpose_cache:
-            k_out[:, :, self.cache_pos[:seq_len]] = k_val
-            v_out[:, :, self.cache_pos[:seq_len]] = v_val
+            pos_mask = torch.arange(k_out.shape[2]) < seq_len
+            k_out[:, :, self.cache_pos[pos_mask]] = k_val
+            v_out[:, :, self.cache_pos[pos_mask]] = v_val
         else:
-            k_out[:, self.cache_pos[:seq_len]] = k_val
-            v_out[:, self.cache_pos[:seq_len]] = v_val
+            pos_mask = torch.arange(k_out.shape[1]) < seq_len
+            k_out[:, self.cache_pos[pos_mask], :] = k_val
+            v_out[:, self.cache_pos[pos_mask], :] = v_val
 
         # forward cache_pos seq_len positions along
         # cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
@@ -124,7 +126,8 @@ def update(
         # this allows us to track the current position in the cache
         # after the last update in a compile-friendly way without any dynamism
         # e.g. relying on an int size tracker, or re-creating cache_pos every time
-        self.cache_pos.add_(seq_len)
+        mask = (seq_len > 0) * 1
+        self.cache_pos.add_(seq_len * mask)
 
         return k_out, v_out
 
diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py
index 82ee1febf49..759de08217d 100644
--- a/extension/llm/modules/test/test_attention.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -219,28 +219,29 @@ def test_attention_torch_cond_eager(self):
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
 
-        # mask
         mask = self.causal_mask[self.input_pos, :]
-        # First run
+        # First run. For the second input to forward, it doesn't matter whether it
+        # is the same as the first (self attention) or a different tensor (cross
+        # attention).
         et_res = self.et_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
-        ) # Self attention with input pos.
+        )
         tt_res = self.tt_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
-        ) # Self attention with input pos.
+        )
 
         self.assertTrue(torch.allclose(et_res, tt_res))
 
-        # Second run test kv cache read. Input pos is [10, 11, ..., 19]
+        # Second run tests the kv cache read. Input pos is [10, 11, ..., 19].
         next_input_pos = torch.arange(10, 20).unsqueeze(0)
         empty_y = torch.full_like(self.x, torch.nan)
         mask = self.causal_mask[next_input_pos, :]
         et_res = self.et_mha(
             self.x, empty_y, mask=mask, input_pos=next_input_pos
-        ) # Self attention with input pos.
+        ) # Cross attention with no y input; ET passes a tensor of NaN values.
         tt_res = self.tt_mha(
             self.x, None, mask=mask, input_pos=next_input_pos
-        ) # Self attention with input pos.
+        ) # Cross attention with no y input; TorchTune passes None.
         assert_close(et_res, tt_res)
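
Note for reviewers: below is a minimal standalone sketch of the masked cache-position update used in the kv_cache.py hunk above. The tensor sizes and names (MAX_SEQ_LEN, SEQ_LEN, etc.) are illustrative, not taken from the module. It shows that indexing cache_pos with a boolean mask built from torch.arange(max_seq_len) < seq_len writes the same positions as the cache_pos[:seq_len] slice it replaces, while avoiding a slice whose length depends on runtime data.

# Standalone sketch (illustrative sizes, not the module's code).
import torch

MAX_SEQ_LEN = 16
BSZ, HEADS, HEAD_DIM, SEQ_LEN = 1, 2, 4, 3

# cache_pos starts as (0, 1, ..., max_seq_len - 1) and is advanced after each update.
cache_pos = torch.arange(MAX_SEQ_LEN)
k_cache = torch.zeros(BSZ, HEADS, MAX_SEQ_LEN, HEAD_DIM)
k_val = torch.randn(BSZ, HEADS, SEQ_LEN, HEAD_DIM)

# Mask-based update: a fixed-size boolean mask instead of a data-dependent slice.
pos_mask = torch.arange(MAX_SEQ_LEN) < SEQ_LEN
k_cache[:, :, cache_pos[pos_mask]] = k_val

# Equivalent slice-based update for comparison.
k_cache_ref = torch.zeros(BSZ, HEADS, MAX_SEQ_LEN, HEAD_DIM)
k_cache_ref[:, :, cache_pos[:SEQ_LEN]] = k_val
assert torch.equal(k_cache, k_cache_ref)

# Advance the write positions by seq_len, as the patched update() does.
cache_pos.add_(SEQ_LEN)
print(cache_pos[:5])  # tensor([3, 4, 5, 6, 7])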
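
Likewise, a hedged sketch of the view -> expand -> reshape sequence that the patched _sdpa uses to bring num_kv_heads up to num_heads. The sizes are made up for illustration, and the repeat_interleave cross-check is only a sanity comparison, not part of the patch.

# Standalone sketch (illustrative sizes): expanding KV heads to match query heads.
import torch

bsz, seq_len, num_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 16
q_per_kv = num_heads // num_kv_heads

k = torch.randn(bsz, seq_len, num_kv_heads * head_dim)

# view -> [bsz, seq_len, n_kv, 1, h_d]; expand copies across the q_per_kv dim;
# reshape flattens (n_kv, q_per_kv) into n_h, matching q's layout.
k5 = k.view(bsz, -1, num_kv_heads, 1, head_dim)
k5 = k5.expand(bsz, -1, num_kv_heads, q_per_kv, head_dim)
k_expanded = k5.reshape(bsz, -1, num_heads, head_dim)

# Reference: repeat each KV head q_per_kv times along the head dimension.
k_ref = k.view(bsz, seq_len, num_kv_heads, head_dim).repeat_interleave(q_per_kv, dim=2)
assert torch.equal(k_expanded, k_ref)

# [bsz, n_h, s, h_d] for attention, as in the patched _sdpa.
k_expanded = k_expanded.transpose(1, 2)
print(k_expanded.shape)  # torch.Size([2, 8, 5, 16])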