@@ -82,6 +82,51 @@ bool validate_flash_attention_args(
   return true;
 }
 
+bool validate_cache_quant_params_args(
+    const Tensor& t,
+    const Tensor& t_zero_points,
+    const Tensor& t_scales) {
+  ET_CHECK_OR_RETURN_FALSE(
+      t.dim() == t_scales.dim(),
+      "Quantized tensor and scales must have the same number of dimensions");
+  ET_CHECK_OR_RETURN_FALSE(
+      t.dim() == t_zero_points.dim(),
+      "Quantized tensor and zero points must have the same number of dimensions");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t.scalar_type() == ScalarType::Char), "Tensor must be of int8_t type");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t_scales.scalar_type() == ScalarType::Float),
+      "Scales tensor must be of float type");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t_zero_points.scalar_type() == ScalarType::Char),
+      "Zero points tensor must be of int8_t type");
+
+  // Sizes
+  for (int64_t i = 0; i < t.dim() - 1; i++) {
+    ET_CHECK_OR_RETURN_FALSE(
+        (t.size(i) == t_scales.size(i)),
+        "Quantized tensor and scales have different shape"
+        " at dim: %" PRId64 ", t: %zd, t_scales: %zd",
+        i,
+        t.size(i),
+        t_scales.size(i));
+    ET_CHECK_OR_RETURN_FALSE(
+        (t.size(i) == t_zero_points.size(i)),
+        "Quantized tensor and zero points have different shape"
+        " at dim: %" PRId64 ", t: %zd, t_zero_points: %zd",
+        i,
+        t.size(i),
+        t_zero_points.size(i));
+  }
+
+  return true;
+}
+
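For reference, the contract enforced by validate_cache_quant_params_args above is: the quantized tensor is int8, scales are float, zero points are int8, all three share the same rank, and every dimension except the innermost one must match. A minimal standalone sketch of the size portion of that check, using plain std::vector sizes rather than ExecuTorch Tensors (the helper name is made up for illustration):

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true when the quant-param sizes are compatible with the quantized
// tensor's sizes: same rank, and all leading dims equal. The innermost dim of
// the params is deliberately left unconstrained, as in the validator above.
bool quant_param_sizes_compatible(
    const std::vector<int64_t>& t_sizes,
    const std::vector<int64_t>& param_sizes) {
  if (t_sizes.size() != param_sizes.size()) {
    return false;
  }
  for (std::size_t i = 0; i + 1 < t_sizes.size(); ++i) {
    if (t_sizes[i] != param_sizes[i]) {
      return false;
    }
  }
  return true;
}
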
 bool validate_cache_params(
     const Tensor& k_cache,
     const Tensor& v_cache,
@@ -233,7 +278,13 @@ Tensor& flash_attention_kernel_out(
               dropout_p,
               is_causal,
               attn_mask,
-              scale);
+              scale,
+              nullopt,
+              nullopt,
+              nullopt,
+              nullopt,
+              nullopt,
+              nullopt);
         } else if (q_seq_len >= 192) {
           sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
               output,
@@ -243,7 +294,13 @@ Tensor& flash_attention_kernel_out(
               dropout_p,
               is_causal,
               attn_mask,
-              scale);
+              scale,
+              nullopt,
+              nullopt,
+              nullopt,
+              nullopt,
+              nullopt,
+              nullopt);
         } else {
           sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
               output,
@@ -253,28 +310,19 @@ Tensor& flash_attention_kernel_out(
               dropout_p,
               is_causal,
               attn_mask,
-              scale);
+              scale,
+              nullopt,
+              nullopt,
+              nullopt,
+              nullopt,
+              nullopt,
+              nullopt);
         }
       });
   return output;
 }
 
-/*
-  Input params
-  @param[in] q_projected Projected query with query weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] k_projected Projected query with key weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] v_projected Projected query with value weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] key_cache Cache of previous k_projected.
-  Format [n_layers, batch size, max_seq_len, num heads, head dim]
-  @param[in] key_cache Cache of previous v_projected.
-  Format [n_layers, batch size, max_seq_len, num heads, head dim]
-  ....
-  @param[in] start_pos: sequence position
-*/
-Tensor& custom_sdpa_out(
+Tensor& custom_sdpa_out_impl(
     RuntimeContext& ctx,
     const Tensor& q,
     const Tensor& k,
@@ -285,7 +333,13 @@ Tensor& custom_sdpa_out(
     const bool is_causal,
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const optional<double> scale,
-    Tensor& output) {
+    Tensor& output,
+    const optional<Tensor>& q_zero_points = nullopt,
+    const optional<Tensor>& q_scales = nullopt,
+    const optional<Tensor>& k_zero_points = nullopt,
+    const optional<Tensor>& k_scales = nullopt,
+    const optional<Tensor>& v_zero_points = nullopt,
+    const optional<Tensor>& v_scales = nullopt) {
   ET_KERNEL_CHECK_MSG(
       ctx,
       !attn_mask.has_value() || !is_causal,
@@ -300,6 +354,40 @@ Tensor& custom_sdpa_out(
       output,
       "Invalid arguments");
 
+  bool is_seq_at_dim_1{true};
+  if (q.scalar_type() == ScalarType::Char) {
+    is_seq_at_dim_1 = false;
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        q_scales.has_value() && q_zero_points.has_value() &&
+            k_scales.has_value() && k_zero_points.has_value() &&
+            v_scales.has_value() && v_zero_points.has_value(),
+        InvalidArgument,
+        output,
+        "If q is quantized, k and v must be quantized as well");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            q, q_zero_points.value(), q_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized query");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            k, k_zero_points.value(), k_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized key");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            v, v_zero_points.value(), v_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized value");
+  }
+
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
   const int64_t seq_len = q.size(1);
@@ -315,53 +403,103 @@ Tensor& custom_sdpa_out(
 
   // TODO(task): replace the template param selection logic
   // with whatever appropriately makes more sense for
-  ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-    // TODO we need to re-evaluate this for ARM CPUs
-    // And there can be many so instead of templatizing
-    // we might consider another approach
-    if (q_seq_len >= 768) {
-      sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    } else if (q_seq_len >= 192) {
-      sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    } else {
-      sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    }
-  });
+  ET_SWITCH_FLOAT_TYPES(
+      output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+        // TODO we need to re-evaluate this for ARM CPUs
+        // And there can be many so instead of templatizing
+        // we might consider another approach
+        if (q_seq_len >= 768) {
+          sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1,
+              start_pos,
+              num_keys_for_causal_attention);
+        } else if (q_seq_len >= 192) {
+          sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1,
+              start_pos,
+              num_keys_for_causal_attention);
+        } else {
+          sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1,
+              start_pos,
+              num_keys_for_causal_attention);
+        }
+      });
   return output;
 }
+
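The q_seq_len thresholds above (>= 768, >= 192, otherwise) select the query/key-value split sizes that become the template parameters of cpu_flash_attention; the TODO notes this selection may need re-tuning per CPU. A standalone sketch of the same selection logic, with a hypothetical helper name that is not part of this change:

#include <cstdint>
#include <utility>

// Mirrors the dispatch in flash_attention_kernel_out / custom_sdpa_out_impl:
// returns {q_split_size, kv_split_size} chosen from the query sequence length.
std::pair<int64_t, int64_t> choose_split_sizes(int64_t q_seq_len) {
  if (q_seq_len >= 768) {
    return {256, 512}; // long prefill: larger query blocks
  } else if (q_seq_len >= 192) {
    return {64, 512};
  }
  return {32, 512}; // short sequences
}
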
+/*
+  Input params
+  @param[in] q_projected Projected query with query weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] k_projected Projected key with key weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] v_projected Projected value with value weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] key_cache Cache of previous k_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  @param[in] value_cache Cache of previous v_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  ....
+  @param[in] start_pos: sequence position
+*/
+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  return custom_sdpa_out_impl(
+      ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
+}
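custom_sdpa_out above is the unquantized entry point: it forwards to custom_sdpa_out_impl and leaves all six quantization parameters at their nullopt defaults. A quantized caller could reuse the same impl by passing an int8 q/k/v together with their scales and zero points; a hypothetical sketch under the same headers and using-declarations as this file (the wrapper name is made up and not part of this change):

Tensor& custom_quantized_sdpa_out( // hypothetical name, for illustration only
    RuntimeContext& ctx,
    const Tensor& q, // int8 (ScalarType::Char)
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
    const optional<Tensor>& attn_mask,
    const double dropout_p,
    const bool is_causal,
    const optional<double> scale,
    Tensor& output,
    const Tensor& q_zero_points,
    const Tensor& q_scales,
    const Tensor& k_zero_points,
    const Tensor& k_scales,
    const Tensor& v_zero_points,
    const Tensor& v_scales) {
  // When q is int8, custom_sdpa_out_impl requires all six quant params,
  // validates them with validate_cache_quant_params_args, and sets
  // is_seq_at_dim_1 = false.
  return custom_sdpa_out_impl(
      ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output,
      q_zero_points, q_scales, k_zero_points, k_scales,
      v_zero_points, v_scales);
}
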
 /*
   Input params
   @param[in] q_projected Projected query with query weights.