@@ -224,7 +224,7 @@ void cpu_flash_attention(
     bool is_causal,
     const optional<Tensor>& attn_mask,
     const optional<double>& scale,
-    bool is_with_kv_cache = false,
+    bool is_seq_at_dim_1 = false,
     const int64_t start_pos = 0) {
   (void)dropout_p;
   // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
@@ -265,7 +265,7 @@ void cpu_flash_attention(
   int64_t kvSize = value.size(2);
   int64_t num_heads_kv = key.size(1);

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
    num_head = query.size(2);
    num_heads_kv = key.size(2);
    qSize = query.size(1);
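For reference, the renamed flag only changes which dimension is read as the head count and which as the sequence length: [Batch, Num_heads, Seq_len, Head_dim] when it is false, [Batch, Seq_len, Num_heads, Head_dim] when it is true. A minimal standalone sketch of that mapping, using a plain std::vector in place of the ExecuTorch Tensor API (the struct and function names here are illustrative, not part of the patch):

#include <cstdint>
#include <vector>

// Illustrative stand-in only: `sizes` plays the role of Tensor::size(d).
// is_seq_at_dim_1 == false: [Batch, Num_heads, Seq_len, Head_dim]
// is_seq_at_dim_1 == true:  [Batch, Seq_len, Num_heads, Head_dim]
struct AttnDims {
  int64_t num_heads;
  int64_t seq_len;
};

inline AttnDims read_dims(const std::vector<int64_t>& sizes, bool is_seq_at_dim_1) {
  if (is_seq_at_dim_1) {
    return {sizes[2], sizes[1]};  // heads at dim 2, sequence at dim 1
  }
  return {sizes[1], sizes[2]};  // heads at dim 1, sequence at dim 2
}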
@@ -311,7 +311,7 @@ void cpu_flash_attention(
   int64_t qStrideH = strides[1];
   int64_t qStrideM = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     qStrideH = strides[2];
     qStrideM = strides[1];
   }
@@ -321,7 +321,7 @@ void cpu_flash_attention(
   int64_t kStrideH = strides[1];
   int64_t kStrideN = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     kStrideH = strides[2];
     kStrideN = strides[1];
   }
@@ -331,7 +331,7 @@ void cpu_flash_attention(
   int64_t vStrideH = strides[1];
   int64_t vStrideN = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     vStrideH = strides[2];
     vStrideN = strides[1];
   }
@@ -341,7 +341,7 @@ void cpu_flash_attention(
   int64_t oStrideH = strides[1];
   int64_t oStrideM = strides[2];

-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     oStrideH = strides[2];
     oStrideM = strides[1];
   }
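The same two-element swap is repeated above for the query, key, value, and output strides. A hedged sketch of how that pattern could be expressed once; the helper below is hypothetical and not part of the change:

#include <cstdint>
#include <utility>

// Hypothetical helper capturing the repeated swap: the head stride and the
// sequence stride come from strides[1]/strides[2] for the [B, H, S, D]
// layout and are swapped when the sequence sits at dim 1.
inline std::pair<int64_t, int64_t> head_and_seq_strides(
    const int64_t* strides,
    bool is_seq_at_dim_1) {
  int64_t stride_h = strides[1];
  int64_t stride_seq = strides[2];
  if (is_seq_at_dim_1) {
    std::swap(stride_h, stride_seq);
  }
  return {stride_h, stride_seq};
}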
@@ -776,7 +776,6 @@ Tensor& custom_sdpa_out(
     const Tensor& k,
     const Tensor& v,
     const int64_t start_pos,
-    const int64_t seq_len,
     const optional<Tensor>& attn_mask,
     const double dropout_p,
     const bool is_causal,
@@ -792,6 +791,7 @@ Tensor& custom_sdpa_out(

   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");

+  const int64_t seq_len = q.size(1);
   auto q_seq_len = q.size(1);

   // Refactor the following into create_view util perhaps using
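With the seq_len parameter removed, the kernel derives the sequence length from the query tensor itself. A compile-and-run sketch of that invariant using a mock tensor type; FakeTensor and the concrete sizes are assumptions made only for illustration:

#include <cassert>
#include <cstdint>
#include <vector>

// FakeTensor stands in for the ExecuTorch Tensor in this illustration only.
struct FakeTensor {
  std::vector<int64_t> sizes;
  int64_t size(size_t d) const { return sizes[d]; }
  int64_t dim() const { return static_cast<int64_t>(sizes.size()); }
};

int main() {
  // custom_sdpa_out treats queries as [Batch, Seq_len, Num_heads, Head_dim]
  // (it calls cpu_flash_attention with is_seq_at_dim_1 = true), so the
  // sequence length is recoverable from the tensor itself.
  FakeTensor q{{1, 7, 8, 64}};  // batch 1, 7 new tokens, 8 heads, head_dim 64
  assert(q.dim() == 4);
  const int64_t seq_len = q.size(1);  // mirrors the line added by the patch
  assert(seq_len == 7);
  (void)seq_len;
  return 0;
}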
@@ -870,7 +870,7 @@ Tensor& custom_sdpa_out(
           is_causal,
           attn_mask,
           scale,
-          true,
+          true, /* is_seq_at_dim_1 */
           start_pos);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention<CTYPE, 64, 512>(
@@ -882,7 +882,7 @@ Tensor& custom_sdpa_out(
           is_causal,
           attn_mask,
           scale,
-          true,
+          true, /* is_seq_at_dim_1 */
           start_pos);
     } else {
       cpu_flash_attention<CTYPE, 32, 512>(
@@ -894,7 +894,7 @@ Tensor& custom_sdpa_out(
           is_causal,
           attn_mask,
           scale,
-          true,
+          true, /* is_seq_at_dim_1 */
           start_pos);
     }
   });
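All three call sites now pass the labeled boolean and differ only in the <q_split_size, kv_split_size> template arguments chosen from q_seq_len. A rough sketch of that selection; only the >= 192 cutoff and the <64, 512> / <32, 512> tiers are visible in this diff, so the large-sequence tier and its threshold below are assumptions:

#include <cstdint>

// Sketch of the block-size dispatch surrounding the patched call sites.
struct BlockSizes {
  int64_t q_split_size;
  int64_t kv_split_size;
};

inline BlockSizes pick_block_sizes(int64_t q_seq_len) {
  // The first branch in the real code (only its tail is visible above)
  // handles long sequences; the 768 cutoff and <256, 512> tier are assumed.
  if (q_seq_len >= 768) {
    return {256, 512};
  }
  if (q_seq_len >= 192) {
    return {64, 512};  // matches cpu_flash_attention<CTYPE, 64, 512>
  }
  return {32, 512};  // matches cpu_flash_attention<CTYPE, 32, 512>
}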
@@ -1017,7 +1017,6 @@ Tensor& sdpa_with_kv_cache_out(
       key_cache,
       value_cache,
       start_pos,
-      seq_len,
       attn_mask,
       dropout_p,
       is_causal,