@@ -176,17 +176,32 @@ void resize_sdpa_out(
   graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
 }
 
-void sdpa_with_kv_cache_impl(
-    ComputeGraph& graph,
-    const std::vector<ValueRef>& args) {
+void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef value = args[arg_idx++];
+  const ValueRef cache = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef out = args[arg_idx++];
+
+  // Unused variables
+  (void)out;
+
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));
+
+  add_kv_cache_update_node(graph, input_pos_symint, value, cache);
+}
+
+void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   int arg_idx = 0;
   const ValueRef q_projected = args[arg_idx++];
-  const ValueRef k_projected = args[arg_idx++];
-  const ValueRef v_projected = args[arg_idx++];
-  const ValueRef k_cache_data = args[arg_idx++];
-  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef k_cache = args[arg_idx++];
+  const ValueRef v_cache = args[arg_idx++];
   const ValueRef input_pos_symint = args[arg_idx++];
-  const ValueRef sequence_len = args[arg_idx++];
   const ValueRef attn_mask = args[arg_idx++];
   const ValueRef dropout_p = args[arg_idx++];
   const ValueRef is_causal = args[arg_idx++];
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
   // Output tensors
   const ValueRef out = args[arg_idx++];
 
-  // Unused variables
-  (void)sequence_len;
-
   // Batches must be 1
   VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
   // k and v projected must have the same shape
-  VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
+  VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
   // head dim must match between tensors
   VK_CHECK_COND(
       graph.size_at<int32_t>(-1, q_projected) ==
-      graph.size_at<int32_t>(-1, k_projected));
+      graph.size_at<int32_t>(-1, k_cache));
   // All tensors must have the packed dim be the width (head) dimension
   VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
   // Some variables are not supported yet
   VK_CHECK_COND(
       graph.val_is_none(dropout_p) ||
@@ -222,16 +234,8 @@ void sdpa_with_kv_cache_impl(
       graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
-  const ValueRef k_cache =
-      prepack_standard_like(graph, k_cache_data, q_projected);
-  const ValueRef v_cache =
-      prepack_standard_like(graph, v_cache_data, q_projected);
-
   const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);
 
-  add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
-  add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);
-
   // Slice caches from 0 to input_pos + sequence_len
   const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
   const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
@@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl(
 
   // Repeat interleave
   const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
-  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
+  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);
 
   const ValueRef num_repeats =
       graph.add_scalar<int64_t>(num_heads / num_kv_heads);
@@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl(
       new ExecuteNode(resize_sdpa_out, {q_projected, out}));
 }
 
+void sdpa_with_kv_cache_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q_projected = args[arg_idx++];
+  const ValueRef k_projected = args[arg_idx++];
+  const ValueRef v_projected = args[arg_idx++];
+  const ValueRef k_cache_data = args[arg_idx++];
+  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef sequence_len = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  const ValueRef dropout_p = args[arg_idx++];
+  const ValueRef is_causal = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+
+  // Output tensors
+  const ValueRef out = args[arg_idx++];
+
+  (void)sequence_len;
+
+  const ValueRef k_cache =
+      prepack_standard_like(graph, k_cache_data, q_projected);
+  const ValueRef v_cache =
+      prepack_standard_like(graph, v_cache_data, q_projected);
+
+  update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
+  update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
+
+  sdpa_impl(
+      graph,
+      {q_projected,
+       k_cache,
+       v_cache,
+       input_pos_symint,
+       attn_mask,
+       dropout_p,
+       is_causal,
+       scale,
+       out});
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
+  VK_REGISTER_OP(update_cache.default, update_cache_impl);
+  VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
 }
 
 } // namespace vkcompute
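For intuition, here is a minimal CPU-side sketch of the copy that add_kv_cache_update_node schedules on the GPU, under the shape checks above: value is [1, seq_len, n_kv_heads, head_dim] and cache is [1, max_seq_len, n_kv_heads, head_dim]. The function name, float dtype, and contiguous layout are illustrative assumptions, not part of this commit.

#include <cstdint>
#include <vector>

// Hypothetical reference for the cache update encoded by
// add_kv_cache_update_node: copy seq_len rows of the projected K (or V)
// tensor into the cache starting at offset input_pos along the sequence
// dimension. Assumes contiguous [1, seq, n_kv_heads, head_dim] layouts.
void update_kv_cache_reference(
    const std::vector<float>& value, // [1, seq_len, n_kv_heads, head_dim]
    std::vector<float>& cache,       // [1, max_seq_len, n_kv_heads, head_dim]
    int64_t seq_len,
    int64_t n_kv_heads,
    int64_t head_dim,
    int64_t input_pos) {
  const int64_t row = n_kv_heads * head_dim; // elements per sequence position
  for (int64_t s = 0; s < seq_len; ++s) {
    for (int64_t e = 0; e < row; ++e) {
      cache[(input_pos + s) * row + e] = value[s * row + e];
    }
  }
}

After the update, sdpa_impl only reads the caches (sliced up to input_pos + seq_len, with K/V heads repeat-interleaved by num_heads / num_kv_heads), which is what allows update_cache.default and llama.custom_sdpa.default to be registered as independent ops while sdpa_with_kv_cache.default remains a thin composition of the two.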