@@ -273,7 +273,6 @@ Tensor& flash_attention_kernel_out(
     Format [n_layers, batch size, max_seq_len, num heads, head dim]
   ....
   @param[in] start_pos: sequence position
-  @param[in] seq_len: Seq length. e.g. seq_len dim of q_projected.
 */
 Tensor& custom_sdpa_out(
     RuntimeContext& ctx,
@@ -306,63 +305,7 @@ Tensor& custom_sdpa_out(
   const int64_t seq_len = q.size(1);
   auto q_seq_len = q.size(1);
 
-  // Refactor the following into create_view util perhaps using
-  // TensorPtr
-  std::array<::executorch::aten::DimOrderType, sdpa::impl::kKVDim>
-      sliced_key_dim_order{0, 1, 2, 3};
-  std::array<::executorch::aten::SizesType, sdpa::impl::kKVDim>
-      sliced_key_sizes;
-  sliced_key_sizes[0] = k.size(0);
-  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
-  sliced_key_sizes[2] = k.size(2);
-  sliced_key_sizes[3] = k.size(3);
-  std::array<::executorch::aten::StridesType, sdpa::impl::kKVDim>
-      sliced_key_strides;
-  dim_order_to_stride_nocheck(
-      sliced_key_sizes.data(),
-      sliced_key_dim_order.data(),
-      sdpa::impl::kKVDim,
-      sliced_key_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_key_strides[0] = k.strides()[0];
-  void* key_cache_data = k.mutable_data_ptr();
-  TensorImpl k_impl = TensorImpl(
-      k.scalar_type(),
-      sdpa::impl::kKVDim,
-      sliced_key_sizes.data(),
-      key_cache_data,
-      sliced_key_dim_order.data(),
-      sliced_key_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_key_cache(&k_impl);
-
-  std::array<::executorch::aten::DimOrderType, sdpa::impl::kKVDim>
-      sliced_value_dim_order{0, 1, 2, 3};
-  std::array<::executorch::aten::SizesType, sdpa::impl::kKVDim>
-      sliced_value_sizes;
-  sliced_value_sizes[0] = v.size(0);
-  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
-  sliced_value_sizes[2] = v.size(2);
-  sliced_value_sizes[3] = v.size(3);
-  std::array<::executorch::aten::StridesType, sdpa::impl::kKVDim>
-      sliced_value_strides;
-  dim_order_to_stride_nocheck(
-      sliced_value_sizes.data(),
-      sliced_value_dim_order.data(),
-      sdpa::impl::kKVDim,
-      sliced_value_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_value_strides[0] = v.strides()[0];
-  void* value_cache_data = v.mutable_data_ptr();
-  TensorImpl value_impl = TensorImpl(
-      v.scalar_type(),
-      sdpa::impl::kKVDim,
-      sliced_value_sizes.data(),
-      value_cache_data,
-      sliced_value_dim_order.data(),
-      sliced_value_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_value_cache(&value_impl);
+  const int64_t num_keys_for_causal_attention = start_pos + seq_len;
 
   ET_KERNEL_CHECK(
       ctx,
@@ -380,38 +323,41 @@ Tensor& custom_sdpa_out(
       sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
           output,
           q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
           dropout_p,
           is_causal,
           attn_mask,
           scale,
           true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
     } else if (q_seq_len >= 192) {
       sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
           output,
           q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
           dropout_p,
           is_causal,
           attn_mask,
           scale,
           true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
     } else {
       sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
          output,
          q,
-          sliced_key_cache,
-          sliced_value_cache,
+          k,
+          v,
          dropout_p,
          is_causal,
          attn_mask,
          scale,
          true, /* is_seq_at_dim_1 */
-          start_pos);
+          start_pos,
+          num_keys_for_causal_attention);
     }
   });
   return output;
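
Context for the change above: instead of materializing sliced views of the key/value caches (the deleted TensorImpl/dim_order_to_stride_nocheck block), the kernel now hands cpu_flash_attention the full caches plus num_keys_for_causal_attention = start_pos + seq_len, and the flash-attention code is expected to bound the key range itself. The sketch below is not the PR's code: it is a minimal, self-contained illustration in plain C++ (hypothetical function name naive_causal_sdpa, single batch, single head, no ExecuTorch types) of why capping the key index at start_pos + seq_len is equivalent to slicing the cache along the sequence dimension.

// Hypothetical illustration only -- not the ExecuTorch kernel.
// Cache layout assumed: [max_seq_len, head_dim], single batch, single head.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void naive_causal_sdpa(
    const std::vector<float>& q, // [seq_len, head_dim], queries at absolute
                                 // positions start_pos .. start_pos+seq_len-1
    const std::vector<float>& key_cache, // [max_seq_len, head_dim]
    const std::vector<float>& value_cache, // [max_seq_len, head_dim]
    std::vector<float>& out, // [seq_len, head_dim], preallocated by the caller
    int64_t seq_len,
    int64_t head_dim,
    int64_t start_pos,
    int64_t num_keys_for_causal_attention) { // == start_pos + seq_len
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  for (int64_t i = 0; i < seq_len; ++i) {
    // Causal mask: the query at absolute position start_pos + i may attend to
    // keys [0, start_pos + i]. Capping at num_keys_for_causal_attention is
    // exactly what slicing the cache to start_pos + seq_len rows used to do.
    const int64_t num_keys =
        std::min<int64_t>(num_keys_for_causal_attention, start_pos + i + 1);
    std::vector<float> scores(num_keys);
    float max_score = -INFINITY;
    for (int64_t j = 0; j < num_keys; ++j) {
      float s = 0.0f;
      for (int64_t d = 0; d < head_dim; ++d) {
        s += q[i * head_dim + d] * key_cache[j * head_dim + d];
      }
      scores[j] = s * scale;
      max_score = std::max(max_score, scores[j]);
    }
    // Softmax over the bounded key range, then weight the cached values.
    float denom = 0.0f;
    for (int64_t j = 0; j < num_keys; ++j) {
      scores[j] = std::exp(scores[j] - max_score);
      denom += scores[j];
    }
    for (int64_t d = 0; d < head_dim; ++d) {
      float acc = 0.0f;
      for (int64_t j = 0; j < num_keys; ++j) {
        acc += (scores[j] / denom) * value_cache[j * head_dim + d];
      }
      out[i * head_dim + d] = acc;
    }
  }
}

Because cache rows at or beyond start_pos + seq_len are never read under this bound, the sliced key/value views removed above can be dropped without changing the result, at the cost of one extra parameter to cpu_flash_attention.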