Skip to content

Commit e8643c0

Browse files
committed
shader variants with/without unrolling
1 parent 4456649 commit e8643c0

File tree

3 files changed

+27
-12
lines changed

3 files changed

+27
-12
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3094,9 +3094,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
30943094
uint32_t conv2d_BS_NPQ = 128;
30953095
uint32_t conv2d_TS_K = 8;
30963096
uint32_t conv2d_SHMEM_PAD = 4;
3097+
bool conv2d_UNROLL = true;
30973098

30983099
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
30993100
conv2d_SHMEM_PAD = 0;
3101+
conv2d_UNROLL = false;
31003102
}
31013103

31023104
switch (s) {
@@ -3141,14 +3143,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
31413143
}
31423144
}
31433145

3144-
ggml_vk_create_pipeline(
3145-
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3146-
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3147-
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }, 1, true, use_collectives);
3148-
ggml_vk_create_pipeline(
3149-
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3150-
sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3151-
{ conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }, 1, true, use_collectives);
3146+
std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
3147+
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
3148+
3149+
if (conv2d_UNROLL) {
3150+
ggml_vk_create_pipeline(
3151+
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
3152+
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3153+
ggml_vk_create_pipeline(
3154+
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
3155+
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3156+
} else {
3157+
ggml_vk_create_pipeline(
3158+
device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3159+
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3160+
ggml_vk_create_pipeline(
3161+
device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3162+
sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3163+
}
31523164
}
31533165

31543166
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ void main() {
202202
Ash[B_ly * Ash_stride + B_lx] = val;
203203
}
204204
/* Load input to B_block: (BS_CRS x BS_NPQ) */
205-
[[unroll]] for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
205+
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
206206
uint32_t B_ly = r_offset + Br; /* Row index of B block */
207207
uint32_t B_lx = Bc;
208208
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
@@ -248,7 +248,7 @@ void main() {
248248
}
249249
barrier();
250250
if (T_y * TS_K < K) {
251-
[[unroll]] for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
251+
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
252252
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
253253
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
254254
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,11 @@ void process_shaders() {
655655

656656
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
657657

658-
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
659-
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
658+
string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
659+
string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
660+
661+
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
662+
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
660663

661664
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
662665
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

0 commit comments

Comments
 (0)