@@ -176,17 +176,32 @@ void resize_sdpa_out(
   graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
 }
 
-void sdpa_with_kv_cache_impl(
-    ComputeGraph& graph,
-    const std::vector<ValueRef>& args) {
+void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef value = args[arg_idx++];
+  const ValueRef cache = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef out = args[arg_idx++];
+
+  // Unused variables
+  (void)out;
+
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));
+
+  add_kv_cache_update_node(graph, input_pos_symint, value, cache);
+}
+
+void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   int arg_idx = 0;
   const ValueRef q_projected = args[arg_idx++];
-  const ValueRef k_projected = args[arg_idx++];
-  const ValueRef v_projected = args[arg_idx++];
-  const ValueRef k_cache_data = args[arg_idx++];
-  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef k_cache = args[arg_idx++];
+  const ValueRef v_cache = args[arg_idx++];
   const ValueRef input_pos_symint = args[arg_idx++];
-  const ValueRef sequence_len = args[arg_idx++];
   const ValueRef attn_mask = args[arg_idx++];
   const ValueRef dropout_p = args[arg_idx++];
   const ValueRef is_causal = args[arg_idx++];
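For orientation, here is a minimal CPU sketch of the write that the new update_cache op schedules through add_kv_cache_update_node. The real update runs in a Vulkan shader; this reference function (update_cache_reference) and its flat contiguous layout are illustrative assumptions only, matching the {1, seq_len, n_kv_heads, head_dim} value and {1, max_seq_len, n_kv_heads, head_dim} cache shapes implied by the checks above.

#include <cstdint>
#include <cstring>

// Hypothetical reference, NOT the shader: copy this step's K or V rows into
// the preallocated cache, starting at row input_pos.
void update_cache_reference(
    const float* value, // {1, seq_len, n_kv_heads, head_dim}, contiguous
    float* cache, // {1, max_seq_len, n_kv_heads, head_dim}, contiguous
    int32_t input_pos,
    int32_t seq_len,
    int32_t n_kv_heads,
    int32_t head_dim) {
  const int32_t row_elems = n_kv_heads * head_dim;
  for (int32_t s = 0; s < seq_len; ++s) {
    // Token s of the incoming chunk lands at absolute position input_pos + s.
    std::memcpy(
        cache + (input_pos + s) * row_elems,
        value + s * row_elems,
        row_elems * sizeof(float));
  }
}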
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
   // Output tensors
   const ValueRef out = args[arg_idx++];
 
-  // Unused variables
-  (void)sequence_len;
-
   // Batches must be 1
   VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
   // k and v projected must have the same shape
-  VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
+  VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
   // head dim must match between tensors
   VK_CHECK_COND(
       graph.size_at<int32_t>(-1, q_projected) ==
-      graph.size_at<int32_t>(-1, k_projected));
+      graph.size_at<int32_t>(-1, k_cache));
   // All tensors must have the packed dim be the width (head) dimension
   VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
   // Some variables are not supported yet
   VK_CHECK_COND(
       graph.val_is_none(dropout_p) ||
@@ -222,16 +234,8 @@ void sdpa_with_kv_cache_impl(
       graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
-  const ValueRef k_cache =
-      prepack_standard_like(graph, k_cache_data, q_projected);
-  const ValueRef v_cache =
-      prepack_standard_like(graph, v_cache_data, q_projected);
-
   const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);
 
-  add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
-  add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);
-
   // Slice caches from 0 to input_pos + sequence_len
   const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
   const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
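Taken together, the checks above pin down the expected layouts. As a worked example with hypothetical Llama-style GQA numbers (illustrative, not from the source):

// Shapes that would satisfy sdpa_impl's checks:
//   q_projected: {1, seq_len, 32, 128}     // batch 1, 32 query heads, head_dim 128
//   k_cache:     {1, max_seq_len, 8, 128}  // 8 KV heads, head_dim matches q
//   v_cache:     {1, max_seq_len, 8, 128}  // must equal k_cache's shape exactly
// All three width-packed; dropout absent or 0, is_causal absent or true, and
// no attn_mask. The sliced views then expose only rows [0, input_pos + seq_len)
// of each cache to the attention computation.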
@@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl(
 
   // Repeat interleave
   const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
-  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
+  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);
 
   const ValueRef num_repeats =
       graph.add_scalar<int64_t>(num_heads / num_kv_heads);
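Continuing the hypothetical head counts above, the repeat factor is just the integer head ratio used by grouped-query attention:

// 32 query heads over 8 KV heads => each KV head is repeated 4 times so the
// expanded K/V line up one-to-one with the query heads (assumes num_heads is
// an exact multiple of num_kv_heads).
const int64_t num_heads = 32;
const int64_t num_kv_heads = 8;
const int64_t repeats = num_heads / num_kv_heads; // == 4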
@@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl(
       new ExecuteNode(resize_sdpa_out, {q_projected, out}));
 }
 
+void sdpa_with_kv_cache_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q_projected = args[arg_idx++];
+  const ValueRef k_projected = args[arg_idx++];
+  const ValueRef v_projected = args[arg_idx++];
+  const ValueRef k_cache_data = args[arg_idx++];
+  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef sequence_len = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  const ValueRef dropout_p = args[arg_idx++];
+  const ValueRef is_causal = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+
+  // Output tensors
+  const ValueRef out = args[arg_idx++];
+
+  (void)sequence_len;
+
+  const ValueRef k_cache =
+      prepack_standard_like(graph, k_cache_data, q_projected);
+  const ValueRef v_cache =
+      prepack_standard_like(graph, v_cache_data, q_projected);
+
+  update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
+  update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
+
+  sdpa_impl(
+      graph,
+      {q_projected,
+       k_cache,
+       v_cache,
+       input_pos_symint,
+       attn_mask,
+       dropout_p,
+       is_causal,
+       scale,
+       out});
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
+  VK_REGISTER_OP(update_cache.default, update_cache_impl);
+  VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
 }
 
 } // namespace vkcompute
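Net effect: sdpa_with_kv_cache is now a thin wrapper over the two new ops, which are also registered standalone (update_cache.default, llama.custom_sdpa.default) so exported graphs can target the decomposed form directly. A comment-level sketch of the wrapper's data flow, using the argument names above (the trailing -1 fills the out slot that update_cache_impl explicitly discards):

// k_cache = prepack_standard_like(k_cache_data, like = q_projected)
// v_cache = prepack_standard_like(v_cache_data, like = q_projected)
// update_cache(k_projected, k_cache, input_pos, /*out=*/-1)  // write K rows
// update_cache(v_projected, v_cache, input_pos, /*out=*/-1)  // write V rows
// out = llama.custom_sdpa(q_projected, k_cache, v_cache, input_pos,
//                         attn_mask, dropout_p, is_causal, scale)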