@@ -82,6 +82,51 @@ bool validate_flash_attention_args(
   return true;
 }
 
+bool validate_cache_quant_params_args(
+    const Tensor& t,
+    const Tensor& t_zero_points,
+    const Tensor& t_scales) {
+  ET_CHECK_OR_RETURN_FALSE(
+      t.dim() == t_scales.dim(),
+      "Quantized tensor and scales must have the same number of dimensions");
+  ET_CHECK_OR_RETURN_FALSE(
+      t.dim() == t_zero_points.dim(),
+      "Quantized tensor and zero points must have the same number of dimensions");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t.scalar_type() == ScalarType::Char), "Tensor must be of int8_t type");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t_scales.scalar_type() == ScalarType::Float),
+      "Scales tensor must be of float type");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      (t_zero_points.scalar_type() == ScalarType::Char),
+      "Zero points tensor must be of int8_t type");
+
+  // Sizes
+  for (int64_t i = 0; i < t.dim() - 1; i++) {
+    ET_CHECK_OR_RETURN_FALSE(
+        (t.size(i) == t_scales.size(i)),
+        "Quantized tensor and scales have different shape"
+        " at dim: %" PRId64 ", t: %zd, t_scales: %zd",
+        i,
+        t.size(i),
+        t_scales.size(i));
+    ET_CHECK_OR_RETURN_FALSE(
+        (t.size(i) == t_zero_points.size(i)),
+        "Quantized tensor and zero points have different shape"
+        " at dim: %" PRId64 ", t: %zd, t_zero_points: %zd",
+        i,
+        t.size(i),
+        t_zero_points.size(i));
+  }
+
+  return true;
+}
+
 bool validate_cache_params(
     const Tensor& k_cache,
     const Tensor& v_cache,
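The new validate_cache_quant_params_args enforces a shape contract (scales and zero points match the quantized tensor on every dimension except the last) and a dtype contract (int8 data and zero points, float scales). A minimal standalone sketch of the shape rule, using plain std::vector shapes rather than the ExecuTorch Tensor type; the names below are illustrative only, not part of the PR:

    // sketch_quant_param_shapes.cpp -- illustrative only.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Mirrors the contract checked above: the quant params share every
    // leading dimension with the quantized tensor; only the last dim
    // (head dim) may differ.
    bool shapes_match_except_last(
        const std::vector<int64_t>& t,
        const std::vector<int64_t>& params) {
      if (t.size() != params.size()) {
        return false;
      }
      for (size_t i = 0; i + 1 < t.size(); ++i) {
        if (t[i] != params[i]) {
          return false;
        }
      }
      return true;
    }

    int main() {
      // Hypothetical int8 key tensor of shape [batch, seq, heads, head_dim]
      // with one scale per (batch, seq, head) row.
      std::vector<int64_t> k = {1, 128, 8, 64};
      std::vector<int64_t> k_scales = {1, 128, 8, 1};
      std::printf("ok: %d\n", shapes_match_except_last(k, k_scales) ? 1 : 0);
      return 0;
    }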
@@ -233,7 +278,13 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     } else if (q_seq_len >= 192) {
       sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
           output,
@@ -243,7 +294,13 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     } else {
       sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
           output,
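These hunks keep flash_attention_kernel_out on the non-quantized path by filling the six new quantization-parameter slots with nullopt. The same pattern in isolation, sketched with std::optional and made-up names, assuming the usual convention that an absent scale/zero point means the value passes through unchanged:

    // sketch_optional_slots.cpp -- illustrative only.
    #include <cstdio>
    #include <optional>

    // A function grows optional quantization parameters; callers on the
    // float path simply pass std::nullopt for each new slot.
    float dequant_or_identity(
        float value,
        std::optional<float> scale = std::nullopt,
        std::optional<int> zero_point = std::nullopt) {
      if (scale.has_value() && zero_point.has_value()) {
        return (value - static_cast<float>(*zero_point)) * (*scale);
      }
      return value; // non-quantized path: behaves exactly as before
    }

    int main() {
      std::printf("%f\n", dequant_or_identity(3.0f));            // float path
      std::printf("%f\n", dequant_or_identity(12.0f, 0.25f, 4)); // quantized path
      return 0;
    }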
@@ -253,28 +310,19 @@ Tensor& flash_attention_kernel_out(
           dropout_p,
           is_causal,
           attn_mask,
-          scale);
+          scale,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt,
+          nullopt);
     }
   });
   return output;
 }
 
-/*
-  Input params
-  @param[in] q_projected Projected query with query weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] k_projected Projected query with key weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] v_projected Projected query with value weights.
-  Format [n_layers, batch size, seq_len, num heads, head dim]
-  @param[in] key_cache Cache of previous k_projected.
-  Format [n_layers, batch size, max_seq_len, num heads, head dim]
-  @param[in] key_cache Cache of previous v_projected.
-  Format [n_layers, batch size, max_seq_len, num heads, head dim]
-  ....
-  @param[in] start_pos: sequence position
-*/
-Tensor& custom_sdpa_out(
+Tensor& custom_sdpa_out_impl(
     RuntimeContext& ctx,
     const Tensor& q,
     const Tensor& k,
@@ -285,7 +333,13 @@ Tensor& custom_sdpa_out(
     const bool is_causal,
     // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
     const optional<double> scale,
-    Tensor& output) {
+    Tensor& output,
+    const optional<Tensor>& q_zero_points = nullopt,
+    const optional<Tensor>& q_scales = nullopt,
+    const optional<Tensor>& k_zero_points = nullopt,
+    const optional<Tensor>& k_scales = nullopt,
+    const optional<Tensor>& v_zero_points = nullopt,
+    const optional<Tensor>& v_scales = nullopt) {
   ET_KERNEL_CHECK_MSG(
       ctx,
       !attn_mask.has_value() || !is_causal,
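custom_sdpa_out_impl now accepts optional zero points and scales for q, k, and v. Assuming these follow the common affine int8 scheme, real = scale * (quantized - zero_point), with (for the sketch) one scale and zero point per head-dim row, here is a standalone sketch of dequantizing a single row; the function and names are illustrative, not the kernel's actual implementation:

    // sketch_affine_dequant.cpp -- illustrative only.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Affine int8 dequantization of one head_dim-sized row, given a single
    // scale and zero point for that row.
    std::vector<float> dequantize_row(
        const std::vector<int8_t>& q_row,
        float scale,
        int8_t zero_point) {
      std::vector<float> out;
      out.reserve(q_row.size());
      for (int8_t v : q_row) {
        out.push_back(
            scale * (static_cast<float>(v) - static_cast<float>(zero_point)));
      }
      return out;
    }

    int main() {
      std::vector<int8_t> row = {-8, 0, 8, 127};
      for (float f : dequantize_row(row, 0.05f, 0)) {
        std::printf("%f ", f);
      }
      std::printf("\n");
      return 0;
    }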
@@ -300,6 +354,40 @@ Tensor& custom_sdpa_out(
       output,
       "Invalid arguments");
 
+  bool is_seq_at_dim_1{true};
+  if (q.scalar_type() == ScalarType::Char) {
+    is_seq_at_dim_1 = false;
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        q_scales.has_value() && q_zero_points.has_value() &&
+            k_scales.has_value() && k_zero_points.has_value() &&
+            v_scales.has_value() && v_zero_points.has_value(),
+        InvalidArgument,
+        output,
+        "If q is quantized, k and v must be quantized as well");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            q, q_zero_points.value(), q_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized query");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            k, k_zero_points.value(), k_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized key");
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        validate_cache_quant_params_args(
+            v, v_zero_points.value(), v_scales.value()),
+        InvalidArgument,
+        output,
+        "Invalid arguments for quantized value");
+  }
+
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
   const int64_t seq_len = q.size(1);
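When q arrives as int8, the block above flips is_seq_at_dim_1 and requires that the scale/zero-point optionals for q, k, and v are all populated before each pair is validated. A compact standalone sketch of that all-or-nothing presence check, with std::optional standing in for the ExecuTorch optional type and purely illustrative names:

    // sketch_all_or_nothing.cpp -- illustrative only.
    #include <cstdio>
    #include <optional>

    struct QuantParams {
      std::optional<int> scales;      // stand-in for the scale tensor
      std::optional<int> zero_points; // stand-in for the zero-point tensor
    };

    // True only if every quantization parameter is supplied; mirrors the
    // "if q is quantized, k and v must be quantized as well" requirement.
    bool all_present(
        const QuantParams& q, const QuantParams& k, const QuantParams& v) {
      return q.scales && q.zero_points && k.scales && k.zero_points &&
          v.scales && v.zero_points;
    }

    int main() {
      QuantParams q{1, 1}, k{1, 1}, v{std::nullopt, std::nullopt};
      std::printf("all present: %d\n", all_present(q, k, v) ? 1 : 0); // prints 0
      return 0;
    }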
@@ -315,53 +403,103 @@ Tensor& custom_sdpa_out(
 
   // TODO(task): replace the template param selection logic
   // with whatever apprpriately makes more sense for
-  ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-    // TODO we need to re-evaluate this for ARM CPUs
-    // And there can be many so instead of templatizing
-    // we might consider another appraoch
-    if (q_seq_len >= 768) {
-      sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    } else if (q_seq_len >= 192) {
-      sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    } else {
-      sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
-          output,
-          q,
-          k,
-          v,
-          dropout_p,
-          is_causal,
-          attn_mask,
-          scale,
-          true, /* is_seq_at_dim_1 */
-          start_pos,
-          num_keys_for_causal_attention);
-    }
-  });
+  ET_SWITCH_FLOAT_TYPES(
+      output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+        // TODO we need to re-evaluate this for ARM CPUs
+        // And there can be many so instead of templatizing
+        // we might consider another appraoch
+        if (q_seq_len >= 768) {
+          sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              start_pos,
+              num_keys_for_causal_attention);
+        } else if (q_seq_len >= 192) {
+          sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              start_pos,
+              num_keys_for_causal_attention);
+        } else {
+          sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+              output,
+              q,
+              k,
+              v,
+              dropout_p,
+              is_causal,
+              attn_mask,
+              scale,
+              nullopt, // q_zero_points
+              nullopt, // q_scales
+              nullopt, // k_zero_points
+              nullopt, // k_scales
+              nullopt, // v_zero_points
+              nullopt, // v_scales
+              is_seq_at_dim_1, /* is_seq_at_dim_1 */
+              start_pos,
+              num_keys_for_causal_attention);
+        }
+      });
   return output;
 }
+
+/*
+  Input params
+  @param[in] q_projected Projected query with query weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] k_projected Projected query with key weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] v_projected Projected query with value weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] key_cache Cache of previous k_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  @param[in] key_cache Cache of previous v_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  ....
+  @param[in] start_pos: sequence position
+*/
+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  return custom_sdpa_out_impl(
+      ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
+}
 /*
   Input params
   @param[in] q_projected Projected query with query weights.