Commit 730c29b

Add batch attention.

1 parent 8946be7

File tree

1 file changed

test/int/nnc/mpsblas.tests.c

Lines changed: 125 additions & 0 deletions
@@ -3383,6 +3383,131 @@ TEST_CASE("scaled dot product attention with quantized NA mps")
 	}
 }
 
+TEST_CASE("scaled dot product attention with quantized NA mps batched")
+{
+	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_MPS));
+	const int B = 3;
+	const int R = 128;
+	const int C = 128;
+	const int H = 8;
+	const int Ds[] = { 64, 128 };
+	const int datatypes[] = { CCV_16F, CCV_16BF, CCV_32F };
+	const float tolerances[] = { 2e-2, 3e-2, 2e-2 };
+	const char* datatype_names[] = { "16F", "16BF", "32F" };
+	for (int d_idx = 0; d_idx < (int)(sizeof(Ds) / sizeof(Ds[0])); ++d_idx)
+	{
+		const int D = Ds[d_idx];
+		const float scale = 1.0 / sqrt((float)D);
+
+		ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
+		ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, H, D), 0);
+		ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, H, D), 0);
+		ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, H, D), 0);
+		ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, H, D), 0);
+		ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, H, D), 0);
+		const int q_count = B * R * H * D;
+		const int kv_count = B * C * H * D;
+		dsfmt_t dsfmt;
+		dsfmt_init_gen_rand(&dsfmt, 101 + d_idx);
+		for (int i = 0; i < q_count; ++i)
+			q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) - 0.5;
+		for (int i = 0; i < kv_count; ++i)
+			k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) - 0.5;
+		for (int i = 0; i < kv_count; ++i)
+			v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) - 0.5;
+
+		ccv_nnc_tensor_t* const o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
+		ccv_nnc_cmd_t cpu_cmd = CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0);
+		ccv_nnc_cmd_exec(cpu_cmd, ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(o_tensor), 0);
+
+		for (int datatype_idx = 0; datatype_idx < 3; ++datatype_idx)
+		{
+			const int datatype = datatypes[datatype_idx];
+			ccv_nnc_tensor_t* q_input = q_tensor;
+			ccv_nnc_tensor_t* k_input = k_tensor;
+			ccv_nnc_tensor_t* v_input = v_tensor;
+			ccv_nnc_tensor_t* copy_of_gpu_o_tensor = 0;
+			ccv_nnc_tensor_t* gpu_q_tensor = 0;
+			ccv_nnc_tensor_t* gpu_k_tensor = 0;
+			ccv_nnc_tensor_t* gpu_v_tensor = 0;
+			ccv_nnc_tensor_t* gpu_o_tensor = 0;
+			if (datatype == CCV_16F)
+			{
+				ccv_float_to_half_precision(q_tensor->data.f32, (uint16_t*)q_tensor_f16->data.f16, q_count);
+				ccv_float_to_half_precision(k_tensor->data.f32, (uint16_t*)k_tensor_f16->data.f16, kv_count);
+				ccv_float_to_half_precision(v_tensor->data.f32, (uint16_t*)v_tensor_f16->data.f16, kv_count);
+				q_input = q_tensor_f16;
+				k_input = k_tensor_f16;
+				v_input = v_tensor_f16;
+				gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, H, D), 0);
+				gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, H, D), 0);
+				gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, H, D), 0);
+				gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, H, D), 0);
+				copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, H, D), 0);
+			} else if (datatype == CCV_16BF) {
+				ccv_float_to_bfloat(q_tensor->data.f32, (uint16_t*)q_tensor_f16->data.f16, q_count);
+				ccv_float_to_bfloat(k_tensor->data.f32, (uint16_t*)k_tensor_f16->data.f16, kv_count);
+				ccv_float_to_bfloat(v_tensor->data.f32, (uint16_t*)v_tensor_f16->data.f16, kv_count);
+				q_input = q_tensor_f16;
+				k_input = k_tensor_f16;
+				v_input = v_tensor_f16;
+				gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16BF, B, R, H, D), 0);
+				gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16BF, B, C, H, D), 0);
+				gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16BF, B, C, H, D), 0);
+				gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16BF, B, R, H, D), 0);
+				copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16BF, B, R, H, D), 0);
+			} else {
+				gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, R, H, D), 0);
+				gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, C, H, D), 0);
+				gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, C, H, D), 0);
+				gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, R, H, D), 0);
+				copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
+			}
+			ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_input, k_input, v_input), TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor), 0);
+			ccv_nnc_cmd_t gpu_cmd = CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0);
+			gpu_cmd.info.scaled_dot_product_attention.flags = CCV_NNC_GEMM_16F | CCV_NNC_GEMM_8I;
+			ccv_nnc_cmd_exec(gpu_cmd, ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor), TENSOR_LIST(gpu_o_tensor), 0);
+			ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_o_tensor), TENSOR_LIST(copy_of_gpu_o_tensor), 0);
+
+			const int count = B * R * H * D;
+			float* const cpu_f32 = (float*)ccmalloc(sizeof(float) * count);
+			float* const gpu_f32 = (float*)ccmalloc(sizeof(float) * count);
+			memcpy(cpu_f32, o_tensor->data.f32, sizeof(float) * count);
+			if (datatype == CCV_16F)
+				ccv_half_precision_to_float((uint16_t*)copy_of_gpu_o_tensor->data.f16, gpu_f32, count);
+			else if (datatype == CCV_16BF)
+				ccv_bfloat_to_float((uint16_t*)copy_of_gpu_o_tensor->data.f16, gpu_f32, count);
+			else
+				memcpy(gpu_f32, copy_of_gpu_o_tensor->data.f32, sizeof(float) * count);
+			float max_relative_diff = 0;
+			int max_diff_idx = 0;
+			for (int i = 0; i < count; ++i)
+			{
+				const float denom = fmaxf(fmaxf(fabsf(cpu_f32[i]), fabsf(gpu_f32[i])), 1.0f);
+				const float relative_diff = fabsf(cpu_f32[i] - gpu_f32[i]) / denom;
+				if (relative_diff > max_relative_diff)
+					max_relative_diff = relative_diff, max_diff_idx = i;
+			}
+			REQUIRE(max_relative_diff <= tolerances[datatype_idx], "quantized batched attention result should match CPU reference for dtype=%s D=%d (max relative diff %g at %d: %g vs %g)", datatype_names[datatype_idx], D, max_relative_diff, max_diff_idx, cpu_f32[max_diff_idx], gpu_f32[max_diff_idx]);
+
+			ccfree(cpu_f32);
+			ccfree(gpu_f32);
+			ccv_nnc_tensor_free(gpu_o_tensor);
+			ccv_nnc_tensor_free(copy_of_gpu_o_tensor);
+			ccv_nnc_tensor_free(gpu_q_tensor);
+			ccv_nnc_tensor_free(gpu_k_tensor);
+			ccv_nnc_tensor_free(gpu_v_tensor);
+		}
+		ccv_nnc_tensor_free(o_tensor);
+		ccv_nnc_tensor_free(q_tensor);
+		ccv_nnc_tensor_free(k_tensor);
+		ccv_nnc_tensor_free(v_tensor);
+		ccv_nnc_tensor_free(q_tensor_f16);
+		ccv_nnc_tensor_free(k_tensor_f16);
+		ccv_nnc_tensor_free(v_tensor_f16);
+	}
+}
+
 TEST_CASE("scaled dot product attention with quantized NA mps for non-multiple-of-64 sequence")
 {
 	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_MPS));
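Note on the acceptance check: the test compares the MPS output against the 32F CPU reference element-wise, flooring the denominator at 1.0f so that outputs below 1 in magnitude are judged by absolute error and larger ones by relative error. For reference, the criterion distills to this standalone helper (a restatement of the comparison loop above, not part of the commit):

#include <math.h>

/* Max element-wise discrepancy between a reference and a test buffer.
   The 1.0f floor in the denominator makes the metric behave like an
   absolute error for values below 1 and a relative error above it. */
static float max_relative_diff(const float* ref, const float* out, int count)
{
	float max_diff = 0;
	for (int i = 0; i < count; ++i)
	{
		const float denom = fmaxf(fmaxf(fabsf(ref[i]), fabsf(out[i])), 1.0f);
		const float diff = fabsf(ref[i] - out[i]) / denom;
		if (diff > max_diff)
			max_diff = diff;
	}
	return max_diff;
}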
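Note on the tolerances: the per-datatype thresholds (2e-2 for 16F and 32F, 3e-2 for 16BF) track storage precision. IEEE half keeps 10 explicit mantissa bits while bfloat16 keeps only 7, so bfloat16 inputs lose noticeably more precision before the kernel even runs. A minimal round-trip sketch using the same ccv conversion helpers the test calls (header location and exact signatures assumed from their use above):

#include <stdio.h>
#include <stdint.h>
#include <math.h>
#include <ccv.h> /* assumed: declares the ccv conversion helpers used in the test */

int main(void)
{
	float x = 0.123456789f, rt;
	uint16_t bits;
	/* float -> fp16 -> float: roughly 3 decimal digits survive. */
	ccv_float_to_half_precision(&x, &bits, 1);
	ccv_half_precision_to_float(&bits, &rt, 1);
	printf("fp16 round trip: %.9g (abs err %g)\n", rt, fabsf(rt - x));
	/* float -> bf16 -> float: roughly 2 decimal digits survive. */
	ccv_float_to_bfloat(&x, &bits, 1);
	ccv_bfloat_to_float(&bits, &rt, 1);
	printf("bf16 round trip: %.9g (abs err %g)\n", rt, fabsf(rt - x));
	return 0;
}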
