@@ -3094,9 +3094,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
30943094 uint32_t conv2d_BS_NPQ = 128;
30953095 uint32_t conv2d_TS_K = 8;
30963096 uint32_t conv2d_SHMEM_PAD = 4;
3097+ bool conv2d_UNROLL = true;
30973098
30983099 if (device->vendor_id == VK_VENDOR_ID_INTEL) {
30993100 conv2d_SHMEM_PAD = 0;
3101+ conv2d_UNROLL = false;
31003102 }
31013103
31023104 switch (s) {
@@ -3141,14 +3143,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
31413143 }
31423144 }
31433145
3144- ggml_vk_create_pipeline(
3145- device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3146- sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3147- { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }, 1, true, use_collectives);
3148- ggml_vk_create_pipeline(
3149- device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3150- sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 },
3151- { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }, 1, true, use_collectives);
3146+ std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
3147+ std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
3148+
3149+ if (conv2d_UNROLL) {
3150+ ggml_vk_create_pipeline(
3151+ device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
3152+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3153+ ggml_vk_create_pipeline(
3154+ device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3,
3155+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3156+ } else {
3157+ ggml_vk_create_pipeline(
3158+ device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3,
3159+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3160+ ggml_vk_create_pipeline(
3161+ device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3,
3162+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3163+ }
31523164 }
31533165
31543166 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
0 commit comments