@@ -3383,6 +3383,138 @@ TEST_CASE("scaled dot product attention with quantized NA mps")
 	}
 }
 
+TEST_CASE("scaled dot product attention with quantized NA mps batched")
+{
+	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_MPS));
+	const int B = 3;
+	const int R = 128;
+	const int C = 128;
+	const int H = 8;
+	const int Ds[] = { 64, 128 };
+	const int datatypes[] = { CCV_16F, CCV_16BF, CCV_32F };
+	const float tolerances[] = { 2e-2, 3e-2, 2e-2 };
+	const char* datatype_names[] = { "16F", "16BF", "32F" };
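+	// datatypes, tolerances and datatype_names are parallel arrays indexed by precision.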
+	for (int d_idx = 0; d_idx < (int)(sizeof(Ds) / sizeof(Ds[0])); ++d_idx)
+	{
+		const int D = Ds[d_idx];
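+		// Standard attention scaling of 1 / sqrt(head dimension).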
+		const float scale = 1.0 / sqrt((float)D);
+
+		ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
+		ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, H, D), 0);
+		ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, H, D), 0);
+		ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, H, D), 0);
+		ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, H, D), 0);
+		ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, H, D), 0);
+		const int q_count = B * R * H * D;
+		const int kv_count = B * C * H * D;
+		dsfmt_t dsfmt;
+		dsfmt_init_gen_rand(&dsfmt, 101 + d_idx);
+		for (int i = 0; i < q_count; ++i)
+			q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) - 0.5;
+		for (int i = 0; i < kv_count; ++i)
+			k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) - 0.5;
+		for (int i = 0; i < kv_count; ++i)
+			v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) - 0.5;
+
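+		// Compute the FP32 reference output on the CPU.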
+		ccv_nnc_tensor_t* const o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
+		ccv_nnc_cmd_t cpu_cmd = CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0);
+		ccv_nnc_cmd_exec(cpu_cmd, ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(o_tensor), 0);
+
+		for (int datatype_idx = 0; datatype_idx < 3; ++datatype_idx)
+		{
+			const int datatype = datatypes[datatype_idx];
+			ccv_nnc_tensor_t* q_input = q_tensor;
+			ccv_nnc_tensor_t* k_input = k_tensor;
+			ccv_nnc_tensor_t* v_input = v_tensor;
+			ccv_nnc_tensor_t* copy_of_gpu_o_tensor = 0;
+			ccv_nnc_tensor_t* gpu_q_tensor = 0;
+			ccv_nnc_tensor_t* gpu_k_tensor = 0;
+			ccv_nnc_tensor_t* gpu_v_tensor = 0;
+			ccv_nnc_tensor_t* gpu_o_tensor = 0;
+			if (datatype == CCV_16F)
+			{
+				ccv_float_to_half_precision(q_tensor->data.f32, (uint16_t*)q_tensor_f16->data.f16, q_count);
+				ccv_float_to_half_precision(k_tensor->data.f32, (uint16_t*)k_tensor_f16->data.f16, kv_count);
+				ccv_float_to_half_precision(v_tensor->data.f32, (uint16_t*)v_tensor_f16->data.f16, kv_count);
+				q_input = q_tensor_f16;
+				k_input = k_tensor_f16;
+				v_input = v_tensor_f16;
+				gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, H, D), 0);
+				gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, H, D), 0);
+				gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, H, D), 0);
+				gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, H, D), 0);
+				copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, H, D), 0);
+			} else if (datatype == CCV_16BF) {
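+				// Reuse the 16F staging tensors as raw 16-bit storage for the bfloat16 payload.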
+				ccv_float_to_bfloat(q_tensor->data.f32, (uint16_t*)q_tensor_f16->data.f16, q_count);
+				ccv_float_to_bfloat(k_tensor->data.f32, (uint16_t*)k_tensor_f16->data.f16, kv_count);
+				ccv_float_to_bfloat(v_tensor->data.f32, (uint16_t*)v_tensor_f16->data.f16, kv_count);
+				q_input = q_tensor_f16;
+				k_input = k_tensor_f16;
+				v_input = v_tensor_f16;
+				gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16BF, B, R, H, D), 0);
+				gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16BF, B, C, H, D), 0);
+				gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16BF, B, C, H, D), 0);
+				gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16BF, B, R, H, D), 0);
+				copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16BF, B, R, H, D), 0);
+			} else {
+				gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, R, H, D), 0);
+				gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, C, H, D), 0);
+				gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, C, H, D), 0);
+				gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, R, H, D), 0);
+				copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
+			}
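+			// Upload the (possibly converted) inputs and run the same attention command on the MPS backend.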
+			ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_input, k_input, v_input), TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor), 0);
+			ccv_nnc_cmd_t gpu_cmd = CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0);
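+			// Request the quantized GEMM path this test exercises (16F compute with 8-bit integer quantization).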
+			gpu_cmd.info.scaled_dot_product_attention.flags = CCV_NNC_GEMM_16F | CCV_NNC_GEMM_8I;
+			ccv_nnc_cmd_exec(gpu_cmd, ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor), TENSOR_LIST(gpu_o_tensor), 0);
+			ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_o_tensor), TENSOR_LIST(copy_of_gpu_o_tensor), 0);
+
+			const int count = B * R * H * D;
+			float* const cpu_f32 = (float*)ccmalloc(sizeof(float) * count);
+			float* const gpu_f32 = (float*)ccmalloc(sizeof(float) * count);
+			memcpy(cpu_f32, o_tensor->data.f32, sizeof(float) * count);
+			if (datatype == CCV_16F)
+				ccv_half_precision_to_float((uint16_t*)copy_of_gpu_o_tensor->data.f16, gpu_f32, count);
+			else if (datatype == CCV_16BF)
+				ccv_bfloat_to_float((uint16_t*)copy_of_gpu_o_tensor->data.f16, gpu_f32, count);
+			else
+				memcpy(gpu_f32, copy_of_gpu_o_tensor->data.f32, sizeof(float) * count);
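+			// Track the worst relative error; the denominator is clamped at 1 so near-zero outputs do not inflate the ratio.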
+			float max_relative_diff = 0;
+			int max_diff_idx = 0;
+			for (int i = 0; i < count; ++i)
+			{
+				const float denom = fmaxf(fmaxf(fabsf(cpu_f32[i]), fabsf(gpu_f32[i])), 1.0f);
+				const float relative_diff = fabsf(cpu_f32[i] - gpu_f32[i]) / denom;
+				if (relative_diff > max_relative_diff)
+					max_relative_diff = relative_diff, max_diff_idx = i;
+			}
+			REQUIRE(max_relative_diff <= tolerances[datatype_idx], "quantized batched attention result should match CPU reference for dtype=%s D=%d (max relative diff %g at %d: %g vs %g)", datatype_names[datatype_idx], D, max_relative_diff, max_diff_idx, cpu_f32[max_diff_idx], gpu_f32[max_diff_idx]);
+
+			ccfree(cpu_f32);
+			ccfree(gpu_f32);
+			ccv_nnc_tensor_free(gpu_o_tensor);
+			ccv_nnc_tensor_free(copy_of_gpu_o_tensor);
+			ccv_nnc_tensor_free(gpu_q_tensor);
+			ccv_nnc_tensor_free(gpu_k_tensor);
+			ccv_nnc_tensor_free(gpu_v_tensor);
+		}
+		ccv_nnc_tensor_free(o_tensor);
+		ccv_nnc_tensor_free(q_tensor);
+		ccv_nnc_tensor_free(k_tensor);
+		ccv_nnc_tensor_free(v_tensor);
+		ccv_nnc_tensor_free(q_tensor_f16);
+		ccv_nnc_tensor_free(k_tensor_f16);
+		ccv_nnc_tensor_free(v_tensor_f16);
+	}
+}
+
 TEST_CASE("scaled dot product attention with quantized NA mps for non-multiple-of-64 sequence")
 {
 	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_MPS));