@@ -96,6 +96,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
     const optional<Tensor> k_scales,
     const optional<Tensor> v_zero_points,
     const optional<Tensor> v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output);
 
 at::Tensor custom_quantized_sdpa_aten(
@@ -115,7 +116,8 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& k_zero_points,
     const std::optional<at::Tensor>& k_scales,
     const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales);
+    const std::optional<at::Tensor>& v_scales,
+    const bool is_seq_at_dim_2);
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 
 Tensor& update_cache_out_no_context(
@@ -258,6 +260,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
     const optional<Tensor> k_scales,
     const optional<Tensor> v_zero_points,
     const optional<Tensor> v_scales,
+    const bool is_seq_at_dim_2,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::custom_quantized_sdpa_out(
@@ -276,6 +279,7 @@ Tensor& custom_quantized_sdpa_out_no_context(
       k_scales,
       v_zero_points,
       v_scales,
+      is_seq_at_dim_2,
       output);
 }
 
@@ -296,9 +300,10 @@ at::Tensor custom_quantized_sdpa_aten(
     const std::optional<at::Tensor>& k_zero_points,
     const std::optional<at::Tensor>& k_scales,
     const std::optional<at::Tensor>& v_zero_points,
-    const std::optional<at::Tensor>& v_scales) {
+    const std::optional<at::Tensor>& v_scales,
+    const bool is_seq_at_dim_2) {
   auto output = at::empty(q.sizes());
-  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 14)
+  WRAP_TO_ATEN(custom_quantized_sdpa_out_no_context, 15)
   (q,
    k,
    v,
@@ -313,6 +318,7 @@ at::Tensor custom_quantized_sdpa_aten(
    k_scales,
    v_zero_points,
    v_scales,
+   is_seq_at_dim_2,
    output);
   return output;
 }
@@ -371,13 +377,13 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
       "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
       "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
-      "Tensor? v_scales=None) -> Tensor");
+      "Tensor? v_scales=None, bool is_seq_at_dim_2=False) -> Tensor");
   m.def(
       "custom_quantized_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
       "float? scale=None, Tensor? q_zero_points=None, Tensor? q_scales=None, "
       "Tensor? k_zero_points=None, Tensor? k_scales=None, Tensor? v_zero_points=None, "
-      "Tensor? v_scales=None, *, Tensor(a!) out) -> Tensor(a!)");
+      "Tensor? v_scales=None, bool is_seq_at_dim_2=False, *, Tensor(a!) out) -> Tensor(a!)");
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
 
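Editorial note: both schema variants above register the new argument with a default of `bool is_seq_at_dim_2=False`, so existing call sites that omit the flag should continue to match the updated signatures; the hypothetical caller-side sketch after the final hunk shows the new argument being passed explicitly.
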
@@ -404,6 +410,6 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl(
       "custom_quantized_sdpa.out",
       WRAP_TO_ATEN(
-          torch::executor::native::custom_quantized_sdpa_out_no_context, 14));
+          torch::executor::native::custom_quantized_sdpa_out_no_context, 15));
 #endif // ENABLE_CUSTOM_QUANTIZED_SDPA
 }
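
Editorial note: the second argument of `WRAP_TO_ATEN` appears to track the position of the `out` tensor in the wrapped signature, which is why it moves from 14 to 15 once `is_seq_at_dim_2` is inserted ahead of `output`. Below is a minimal, hypothetical caller-side sketch of the updated ATen entry point; it is not part of this patch. The leading parameters are restated from the registered schema, and the shapes, dtypes, and layout comments are illustrative assumptions only.

```cpp
#include <ATen/ATen.h>

#include <optional>

namespace torch::executor::native {
// Restated from the schema registered in this patch; treat the exact types of
// the leading parameters as an assumption for the purpose of this sketch.
at::Tensor custom_quantized_sdpa_aten(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const int64_t start_pos,
    const std::optional<at::Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
    const std::optional<double> scale,
    const std::optional<at::Tensor>& q_zero_points,
    const std::optional<at::Tensor>& q_scales,
    const std::optional<at::Tensor>& k_zero_points,
    const std::optional<at::Tensor>& k_scales,
    const std::optional<at::Tensor>& v_zero_points,
    const std::optional<at::Tensor>& v_scales,
    const bool is_seq_at_dim_2);
} // namespace torch::executor::native

at::Tensor example_call() {
  // Illustrative int8 inputs. The default layout presumably keeps the sequence
  // at dim 1 ([batch, seq_len, num_heads, head_dim]); passing
  // is_seq_at_dim_2=true would select a [batch, num_heads, seq_len, head_dim]
  // layout instead (inferred from the flag name, not stated in this patch).
  auto q = at::randint(-128, 127, {1, 16, 8, 64}, at::kChar);
  auto k = at::randint(-128, 127, {1, 16, 8, 64}, at::kChar);
  auto v = at::randint(-128, 127, {1, 16, 8, 64}, at::kChar);
  return torch::executor::native::custom_quantized_sdpa_aten(
      q, k, v, /*start_pos=*/0,
      /*attn_mask=*/std::nullopt, /*dropout_p=*/0.0, /*is_causal=*/true,
      /*scale=*/std::nullopt,
      /*q_zero_points=*/std::nullopt, /*q_scales=*/std::nullopt,
      /*k_zero_points=*/std::nullopt, /*k_scales=*/std::nullopt,
      /*v_zero_points=*/std::nullopt, /*v_scales=*/std::nullopt,
      /*is_seq_at_dim_2=*/false);
}
```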