@@ -224,7 +224,7 @@ void cpu_flash_attention(
     bool is_causal,
     const optional<Tensor>& attn_mask,
     const optional<double>& scale,
-    bool is_with_kv_cache = false,
+    bool is_seq_at_dim_1 = false,
     const int64_t start_pos = 0) {
   (void)dropout_p;
   // Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
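
The rename captures what the flag actually controls: not whether a KV cache is in play, but whether the sequence dimension of the 4D tensors sits at index 1 ([Batch, Seq_len, Num_heads, Dim_per_head]) instead of the default index 2 ([Batch, Num_heads, Seq_len, Dim_per_head]). A minimal sketch of the dimension lookup this implies; the helper names are illustrative, not part of the patch:

  #include <cstdint>

  // Default layout:         [Batch, Num_heads, Seq_len, Dim_per_head]
  // is_seq_at_dim_1 layout:  [Batch, Seq_len, Num_heads, Dim_per_head]
  inline int64_t seq_dim_index(bool is_seq_at_dim_1) {
    return is_seq_at_dim_1 ? 1 : 2;
  }
  inline int64_t heads_dim_index(bool is_seq_at_dim_1) {
    return is_seq_at_dim_1 ? 2 : 1;
  }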
@@ -265,7 +265,7 @@ void cpu_flash_attention(
   int64_t kvSize = value.size(2);
   int64_t num_heads_kv = key.size(1);
 
-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     num_head = query.size(2);
     num_heads_kv = key.size(2);
     qSize = query.size(1);
@@ -311,7 +311,7 @@ void cpu_flash_attention(
   int64_t qStrideH = strides[1];
   int64_t qStrideM = strides[2];
 
-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     qStrideH = strides[2];
     qStrideM = strides[1];
   }
@@ -321,7 +321,7 @@ void cpu_flash_attention(
   int64_t kStrideH = strides[1];
   int64_t kStrideN = strides[2];
 
-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     kStrideH = strides[2];
     kStrideN = strides[1];
   }
@@ -331,7 +331,7 @@ void cpu_flash_attention(
   int64_t vStrideH = strides[1];
   int64_t vStrideN = strides[2];
 
-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     vStrideH = strides[2];
     vStrideN = strides[1];
   }
@@ -341,7 +341,7 @@ void cpu_flash_attention(
   int64_t oStrideH = strides[1];
   int64_t oStrideM = strides[2];
 
-  if (is_with_kv_cache) {
+  if (is_seq_at_dim_1) {
     oStrideH = strides[2];
     oStrideM = strides[1];
   }
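
The same swap is applied uniformly to query, key, value, and output: when the sequence dimension is at index 1, the head stride and the sequence stride trade places, and the kernel body keeps a single indexing expression for both layouts. An illustrative offset computation under that convention (not from the patch; strides are in elements):

  #include <cstdint>

  // Element offset of query[b][h][m][d]. When is_seq_at_dim_1 is true the
  // caller has already swapped strideH and strideM, so this expression is
  // layout-agnostic.
  inline int64_t q_offset(
      int64_t b, int64_t h, int64_t m, int64_t d,
      int64_t strideB, int64_t strideH, int64_t strideM) {
    return b * strideB + h * strideH + m * strideM + d;
  }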
@@ -776,7 +776,6 @@ Tensor& custom_sdpa_out(
     const Tensor& k,
     const Tensor& v,
     const int64_t start_pos,
-    const int64_t seq_len,
     const optional<Tensor>& attn_mask,
     const double dropout_p,
     const bool is_causal,
@@ -792,6 +791,7 @@ Tensor& custom_sdpa_out(
 
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
+  const int64_t seq_len = q.size(1);
   auto q_seq_len = q.size(1);
 
   // Refactor the following into create_view util perhaps using
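
Dropping seq_len from the signature removes a redundant argument: for this entry point the query is in the seq-at-dim-1 layout, so the sequence length is fully determined by q.size(1). Note that after this change seq_len and q_seq_len hold the same value read twice; a follow-up (not in this patch) could collapse them:

  const int64_t seq_len = q.size(1); // was a caller-supplied argument
  auto q_seq_len = seq_len;          // reuse instead of a second q.size(1)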
@@ -870,7 +870,7 @@ Tensor& custom_sdpa_out(
           is_causal,
           attn_mask,
           scale,
-          true,
+          true, /* is_seq_at_dim_1 */
           start_pos);
     } else if (q_seq_len >= 192) {
       cpu_flash_attention<CTYPE, 64, 512>(
@@ -882,7 +882,7 @@ Tensor& custom_sdpa_out(
           is_causal,
           attn_mask,
           scale,
-          true,
+          true, /* is_seq_at_dim_1 */
           start_pos);
     } else {
       cpu_flash_attention<CTYPE, 32, 512>(
@@ -894,7 +894,7 @@ Tensor& custom_sdpa_out(
           is_causal,
           attn_mask,
           scale,
-          true,
+          true, /* is_seq_at_dim_1 */
           start_pos);
     }
   });
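
All three cpu_flash_attention specializations now label the positional true, which otherwise reads as a bare magic constant at the call site. Reduced to its shape, the dispatch picks a larger q_split_size template parameter as the query length grows (the first branch's threshold and template arguments lie outside these hunks and are elided):

  if (q_seq_len >= /* threshold outside hunk */) {
    cpu_flash_attention<CTYPE, /* larger q split */, 512>(/* ... */);
  } else if (q_seq_len >= 192) {
    cpu_flash_attention<CTYPE, 64, 512>(/* ... */);
  } else {
    cpu_flash_attention<CTYPE, 32, 512>(/* ... */);
  }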
@@ -1017,7 +1017,6 @@ Tensor& sdpa_with_kv_cache_out(
       key_cache,
       value_cache,
       start_pos,
-      seq_len,
       attn_mask,
       dropout_p,
       is_causal,