Skip to content

Commit 136ecfb

Browse files
committed
Three tile sizes for CONV_2D, and a heuristic to choose
1 parent 9c12ef7 commit 136ecfb

File tree

1 file changed

+103
-72
lines changed

1 file changed

+103
-72
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 103 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,13 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
319319
return vk_device_architecture::OTHER;
320320
}
321321

322+
enum vk_conv_shapes {
323+
CONV_SHAPE_128x128,
324+
CONV_SHAPE_64x32,
325+
CONV_SHAPE_32x256,
326+
CONV_SHAPE_COUNT,
327+
};
328+
322329
struct vk_device_struct {
323330
std::recursive_mutex mutex;
324331

@@ -483,8 +490,8 @@ struct vk_device_struct {
483490
vk_pipeline pipeline_rwkv_wkv6_f32;
484491
vk_pipeline pipeline_rwkv_wkv7_f32;
485492
vk_pipeline pipeline_opt_step_adamw_f32;
486-
vk_pipeline pipeline_conv2d_f32;
487-
vk_pipeline pipeline_conv2d_f16_f32;
493+
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
494+
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
488495
vk_pipeline pipeline_conv2d_dw_whcn_f32;
489496
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
490497

@@ -3062,59 +3069,61 @@ static void ggml_vk_load_shaders(vk_device& device) {
30623069
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
30633070

30643071
// conv2d
3065-
uint32_t conv2d_WG_SIZE = 256;
3066-
uint32_t conv2d_BS_K = 128;
3067-
uint32_t conv2d_BS_CRS = 16;
3068-
uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
3069-
uint32_t conv2d_BS_NPQ = 128;
3070-
uint32_t conv2d_TS_K = 8;
3071-
uint32_t conv2d_SHMEM_PAD = 4;
3072-
3073-
if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
3074-
conv2d_BS_K = 64;
3075-
conv2d_BS_CRS = 32;
3076-
conv2d_BS_NPQ = 32;
3077-
conv2d_TS_K = 4;
3078-
}
3072+
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
3073+
uint32_t conv2d_WG_SIZE = 256;
3074+
uint32_t conv2d_BS_K = 128;
3075+
uint32_t conv2d_BS_CRS = 16;
3076+
uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
3077+
uint32_t conv2d_BS_NPQ = 128;
3078+
uint32_t conv2d_TS_K = 8;
3079+
uint32_t conv2d_SHMEM_PAD = 4;
3080+
3081+
switch (s) {
3082+
default:
3083+
case CONV_SHAPE_128x128:
3084+
conv2d_BS_K = 128;
3085+
conv2d_BS_NPQ = 128;
3086+
conv2d_BS_CRS = 16;
3087+
break;
3088+
case CONV_SHAPE_64x32:
3089+
conv2d_BS_K = 64;
3090+
conv2d_BS_NPQ = 32;
3091+
conv2d_BS_CRS = 32;
3092+
conv2d_TS_K = 4;
3093+
break;
3094+
case CONV_SHAPE_32x256:
3095+
conv2d_BS_K = 32;
3096+
conv2d_BS_NPQ = 256;
3097+
conv2d_BS_CRS = 16;
3098+
break;
3099+
}
30793100

3080-
if (device->subgroup_shuffle &&
3081-
device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
3082-
device->vendor_id != VK_VENDOR_ID_NVIDIA) { // Collectives no faster on NVIDIA.
3083-
use_collectives = 1;
3084-
conv2d_BS_CRS = std::min(
3085-
device->subgroup_size,
3086-
conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
3087-
}
3101+
if (device->subgroup_shuffle &&
3102+
device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
3103+
device->vendor_id != VK_VENDOR_ID_NVIDIA) { // Collectives no faster on NVIDIA.
3104+
use_collectives = 1;
3105+
conv2d_BS_CRS = std::min(
3106+
device->subgroup_size,
3107+
conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
3108+
}
30883109

3089-
uint32_t conv2d_shmem_req =
3090-
(conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
3091-
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
3092-
conv2d_BS_CRS = 8;
3093-
if (use_collectives) {
3094-
conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
3110+
uint32_t conv2d_shmem_req =
3111+
(conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
3112+
if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
3113+
conv2d_BS_CRS = 8;
3114+
if (use_collectives) {
3115+
conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS);
3116+
}
30953117
}
3096-
}
30973118

3098-
if (use_collectives) {
30993119
ggml_vk_create_pipeline(
3100-
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3120+
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
31013121
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3102-
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
3122+
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, use_collectives);
31033123
ggml_vk_create_pipeline(
3104-
device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3124+
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
31053125
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3106-
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true);
3107-
} else {
3108-
ggml_vk_create_pipeline(
3109-
device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3110-
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3111-
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
3112-
false);
3113-
ggml_vk_create_pipeline(
3114-
device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3115-
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3116-
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true,
3117-
false);
3126+
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, use_collectives);
31183127
}
31193128

31203129
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
@@ -6666,6 +6675,34 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
66666675
}
66676676
}
66686677

6678+
static std::array<uint32_t, 3> ggml_vk_get_conv_elements(const ggml_tensor *dst) {
6679+
const ggml_tensor *src0 = dst->src[0];
6680+
const ggml_tensor *src1 = dst->src[1];
6681+
6682+
// src0 - kernel: [KW, KH, Cin, Cout]
6683+
// src1 - input: [W, H, Cin, N]
6684+
// dst - result: [OW, OH, Cout, N]
6685+
6686+
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
6687+
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
6688+
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
6689+
};
6690+
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
6691+
int64_t W = src1->ne[0];
6692+
int64_t H = src1->ne[1];
6693+
int64_t KW = src0->ne[0];
6694+
int64_t KH = src0->ne[1];
6695+
int64_t Cout = src0->ne[3];
6696+
int64_t N = src1->ne[3];
6697+
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
6698+
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
6699+
int64_t NPQ = N * OW * OH;
6700+
6701+
// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
6702+
std::array<uint32_t, 3> elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
6703+
return elements;
6704+
}
6705+
66696706
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
66706707
switch (op) {
66716708
case GGML_OP_GET_ROWS:
@@ -6995,10 +7032,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
69957032
case GGML_OP_CONV_2D:
69967033
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
69977034
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
7035+
auto elements = ggml_vk_get_conv_elements(dst);
7036+
vk_conv_shapes shape;
7037+
7038+
uint32_t tiles[CONV_SHAPE_COUNT];
7039+
for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
7040+
tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]);
7041+
}
7042+
if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= ctx->device->shader_core_count * 2) {
7043+
shape = CONV_SHAPE_128x128;
7044+
} else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= ctx->device->shader_core_count * 2) {
7045+
shape = CONV_SHAPE_32x256;
7046+
} else {
7047+
shape = CONV_SHAPE_64x32;
7048+
}
7049+
69987050
if (src0->type == GGML_TYPE_F32) {
6999-
return ctx->device->pipeline_conv2d_f32;
7051+
return ctx->device->pipeline_conv2d_f32[shape];
70007052
} else if (src0->type == GGML_TYPE_F16) {
7001-
return ctx->device->pipeline_conv2d_f16_f32;
7053+
return ctx->device->pipeline_conv2d_f16_f32[shape];
70027054
}
70037055
}
70047056
return nullptr;
@@ -7326,29 +7378,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
73267378
} break;
73277379
case GGML_OP_CONV_2D:
73287380
{
7329-
// src0 - kernel: [KW, KH, Cin, Cout]
7330-
// src1 - input: [W, H, Cin, N]
7331-
// dst - result: [OW, OH, Cout, N]
7332-
7333-
// Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
7334-
auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
7335-
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
7336-
};
7337-
// parallelize in {OW/BS_K, OH/BS_NPQ, 1}
7338-
int64_t W = src1->ne[0];
7339-
int64_t H = src1->ne[1];
7340-
int64_t KW = src0->ne[0];
7341-
int64_t KH = src0->ne[1];
7342-
int64_t Cout = src0->ne[3];
7343-
int64_t N = src1->ne[3];
7344-
int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]);
7345-
int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]);
7346-
int64_t NPQ = N * OW * OH;
7347-
7348-
// Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups
7349-
elements = { static_cast<uint32_t>(Cout), static_cast<uint32_t>(NPQ), 1 };
7350-
}
7351-
break;
7381+
elements = ggml_vk_get_conv_elements(dst);
7382+
} break;
73527383
case GGML_OP_ADD:
73537384
case GGML_OP_SUB:
73547385
case GGML_OP_DIV:

0 commit comments

Comments
 (0)