
Commit 3468f0c

Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code sharing"
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
2 parents 3a6b545 + 148354d
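For context on where the refactor is headed, below is a minimal sketch of decoupled, standardized KVCache and SDPA modules. The names SimpleKVCache and SimpleSDPA are illustrative placeholders rather than ExecuTorch's actual module definitions; the point is only that both operate on [B, n_heads, seq_len, head_dim] tensors and neither owns the other.

# Illustrative sketch only: placeholder modules showing the standardized interface
# described in the summary; q, k, v are [B, n_heads, seq_len, head_dim] tensors.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleKVCache(nn.Module):
    def __init__(self, max_batch_size, n_heads, max_seq_length, head_dim):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(self, input_pos, k_val, v_val):
        # Write the new slices at the given positions along the sequence dimension
        # and return the full caches for attention over all positions so far.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class SimpleSDPA(nn.Module):
    def forward(self, q, k, v, mask=None):
        # Pure attention: no cache ownership, so it can be swapped independently.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

An attention block then composes the two: it calls the cache's update() first and runs SDPA over the returned full k/v, which is what makes either module independently replaceable by a custom or quantized variant.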

File tree: 2 files changed, +31 -19 lines changed

.ci/scripts/test_llama.sh

Lines changed: 2 additions & 0 deletions
@@ -112,6 +112,8 @@ fi

 if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
   QUANTIZE_KV_CACHE=ON
+  # quantize_kv cache transform uses custom kv cache update op
+  CUSTOM=ON
 else
   QUANTIZE_KV_CACHE=OFF
 fi
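The new CUSTOM=ON line encodes a dependency rather than a preference: the quantized KV cache transform's custom path calls torch.ops.llama.update_cache, which only resolves once the ExecuTorch custom-op library has been built and loaded. A small Python check of that dependency (a sketch, not taken from the repo):

import torch


def custom_update_cache_available() -> bool:
    # Resolving the op succeeds only if the llama custom-op library has been loaded;
    # otherwise the attribute lookup on the op namespace fails.
    try:
        _ = torch.ops.llama.update_cache
        return True
    except (AttributeError, RuntimeError):
        return False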

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 29 additions & 19 deletions
@@ -37,6 +37,7 @@ def __init__(
         n_heads,
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
+        use_custom_update_cache_op: bool = False,
     ):
         super().__init__()
         if cache_type not in (
@@ -48,6 +49,7 @@ def __init__(
         )

         # For now supporting int8 only
+        self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
         cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
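A hedged construction example for the new flag follows. The leading max_batch_size and max_seq_length parameters are not visible in this hunk and are assumed from the cache_shape line, keyword arguments keep the sketch independent of their exact order, and the import path simply mirrors the file location.

from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)

# Sketch only: the non-custom path keeps the cache usable without the custom op library.
cache = QuantizedKVCache(
    max_batch_size=1,
    max_seq_length=2048,
    n_heads=32,
    head_dim=128,
    cache_type=QuantizedCacheType.AffineAsymmetric,
    use_custom_update_cache_op=False,
)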
@@ -103,24 +105,25 @@ def update(self, input_pos, k_val, v_val):

         quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

-        # Right now using custom ops on this path.
-        # In future we can update custom op to handle transposed cache
-        # as well.
-        # Note that we may have to revert this change if other ET
-        # backends such as QNN want to use quantized cache, with dynamic shape,
-        # instead of quantizing on their own.
-        # But until this opting for code simplicity
-        start_pos = input_pos[0].item()
-        _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
-        _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
-        _ = torch.ops.llama.update_cache(
-            k_zero_points, self.k_cache_zero_points, start_pos
-        )
-        _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
-        _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
-        _ = torch.ops.llama.update_cache(
-            v_zero_points, self.v_cache_zero_points, start_pos
-        )
+        if self.use_custom_update_cache_op:
+            start_pos = input_pos[0].item()
+            _ = torch.ops.llama.update_cache(quantized_k_val, self.k_cache, start_pos)
+            _ = torch.ops.llama.update_cache(k_scales, self.k_cache_scales, start_pos)
+            _ = torch.ops.llama.update_cache(
+                k_zero_points, self.k_cache_zero_points, start_pos
+            )
+            _ = torch.ops.llama.update_cache(quantized_v_val, self.v_cache, start_pos)
+            _ = torch.ops.llama.update_cache(v_scales, self.v_cache_scales, start_pos)
+            _ = torch.ops.llama.update_cache(
+                v_zero_points, self.v_cache_zero_points, start_pos
+            )
+        else:
+            self.k_cache[:, :, input_pos] = quantized_k_val
+            self.k_cache_scales[:, :, input_pos] = k_scales
+            self.k_cache_zero_points[:, :, input_pos] = k_zero_points
+            self.v_cache[:, :, input_pos] = quantized_v_val
+            self.v_cache_scales[:, :, input_pos] = v_scales
+            self.v_cache_zero_points[:, :, input_pos] = v_zero_points

         k_out = torch.ops.quantized_decomposed.dequantize_per_token(
             self.k_cache,
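For readers unfamiliar with the custom op, each torch.ops.llama.update_cache call above amounts to an in-place write of the freshly quantized slice (or its scales/zero points) into the persistent cache buffer at start_pos along the sequence dimension; the else branch exists so the quantized cache still works when the custom op is not linked. A plain-PyTorch stand-in for the op's effect, assuming the [batch, seq_len, n_heads, head_dim] cache layout from the __init__ hunk:

import torch


def update_cache_reference(value: torch.Tensor, cache: torch.Tensor, start_pos: int) -> None:
    # value: [B, seq_len, H, D]; cache: [B, max_seq_len, H, D]. In-place copy, no new allocation.
    seq_len = value.size(1)
    cache.narrow(1, start_pos, seq_len).copy_(value)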
@@ -148,7 +151,12 @@ def update(self, input_pos, k_val, v_val):
         return k_out.transpose(1, 2), v_out.transpose(1, 2)

     @classmethod
-    def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
+    def from_float(
+        cls,
+        kv_cache,
+        cache_type: QuantizedCacheType,
+        use_custom_update_cache_op: bool = False,
+    ):
         max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
         if isinstance(kv_cache, CustomKVCache):
             # If replacing custom kv cache, then the shape is [B, S, H, D]
@@ -159,6 +167,7 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
             n_heads,
             head_dim,
             cache_type,
+            use_custom_update_cache_op,
         )
@@ -199,6 +208,10 @@ def replace_kv_cache_with_quantized_kv_cache(module):
                 module,
                 name,
-                QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric),
+                QuantizedKVCache.from_float(
+                    child,
+                    QuantizedCacheType.AffineAsymmetric,
+                    use_custom_update_cache_op=True,
+                ),
             )
         else:
             replace_kv_cache_with_quantized_kv_cache(child)
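Finally, a hedged usage sketch tying the new parameter through both entry points. Here existing_kv_cache and model are placeholders, the import path mirrors the file location, and the transform is assumed to mutate the module tree in place, since the recursion above does not use a return value.

from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
    replace_kv_cache_with_quantized_kv_cache,
)

# 1) Convert a single float cache, forwarding the new flag explicitly.
quantized_cache = QuantizedKVCache.from_float(
    existing_kv_cache,  # placeholder: an existing KVCache or CustomKVCache instance
    QuantizedCacheType.AffineAsymmetric,
    use_custom_update_cache_op=True,
)

# 2) Or apply the transform to a whole model: every KVCache/CustomKVCache child is
#    swapped for a QuantizedKVCache that uses the custom update op.
replace_kv_cache_with_quantized_kv_cache(model)  # placeholder: model is an nn.Module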
