@@ -367,6 +367,12 @@ enum vk_conv_shapes {
367367 CONV_SHAPE_COUNT,
368368};
369369
// Workgroup denominators { BS_K, BS_NPQ, 1 } for each conv2d shader shape.
// The table is indexed with vk_conv_shapes enumerators, so row i must describe
// the enumerator whose value is i (rows below: CONV_SHAPE_128x128,
// CONV_SHAPE_64x32, CONV_SHAPE_32x256).
// It is only ever read, so it is constexpr: immutable and local to this
// translation unit instead of a writable global with external linkage.
static constexpr uint32_t conv_shapes_wg_denoms[][3] = {
    { 128, 128, 1 },
    {  64,  32, 1 },
    {  32, 256, 1 },
};
375+
370376enum dmmv_wg_sizes {
371377 DMMV_WG_SIZE_SUBGROUP,
372378 DMMV_WG_SIZE_LARGE,
@@ -395,6 +401,18 @@ struct vk_fa_pipeline_state {
395401 }
396402};
397403
// Key identifying one specialized conv2d / conv_transpose_2d pipeline.
// Used as the key type of the per-shape std::map<..., vk_pipeline> members of
// the device, so it must provide a strict weak ordering (operator< below).
struct vk_conv2d_pipeline_state {
    vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)
        : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}

    // s*/p*/d* are taken from the op params of the destination tensor
    // (presumably stride, padding and dilation along each axis - confirm
    // against the op definition); KW/KH are the first two dims of the kernel.
    uint32_t s0, s1, p0, p1, d0, d1, KW, KH;

    // Lexicographic comparison over all fields, in declaration order.
    bool operator<(const vk_conv2d_pipeline_state &b) const {
        return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) <
               std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH);
    }
};
415+
398416enum shader_reduction_mode {
399417 SHADER_REDUCTION_MODE_SHMEM,
400418 SHADER_REDUCTION_MODE_HYBRID,
@@ -691,10 +709,10 @@ struct vk_device_struct {
691709 vk_pipeline pipeline_ssm_conv_f32;
692710 vk_pipeline pipeline_opt_step_adamw_f32;
693711 vk_pipeline pipeline_opt_step_sgd_f32;
694- vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
695- vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
696- vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
697- vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
712+ std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
713+ std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
714+ std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
715+ std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
698716 vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
699717 vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
700718
@@ -1274,17 +1292,13 @@ struct vk_op_conv2d_push_constants {
12741292 uint32_t nb2;
12751293 uint32_t nb3;
12761294
1277- // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH
1278- uint32_t KWmp; uint32_t KWL;
1279- uint32_t KWKHmp; uint32_t KWKHL;
1295+ // init_fastdiv_values constants for dividing by OW, OW*OH
12801296 uint32_t OWmp; uint32_t OWL;
12811297 uint32_t OWOHmp; uint32_t OWOHL;
12821298};
12831299
12841300template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
1285- // Compute magic values to divide by KW, KW*KH, OW, OW*OH
1286- init_fastdiv_values(p.KW, p.KWmp, p.KWL);
1287- init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
1301+ // Compute magic values to divide by OW, OW*OH
12881302 init_fastdiv_values(p.OW, p.OWmp, p.OWL);
12891303 init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
12901304}
@@ -1320,23 +1334,15 @@ struct vk_op_conv_transpose_2d_push_constants {
13201334 uint32_t nb2;
13211335 uint32_t nb3;
13221336
1323- // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1
1324- uint32_t KWmp; uint32_t KWL;
1325- uint32_t KWKHmp; uint32_t KWKHL;
1337+ // init_fastdiv_values constants for dividing by OW, OW*OH
13261338 uint32_t OWmp; uint32_t OWL;
13271339 uint32_t OWOHmp; uint32_t OWOHL;
1328- uint32_t s0mp; uint32_t s0L;
1329- uint32_t s1mp; uint32_t s1L;
13301340};
13311341
13321342template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
1333- // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1
1334- init_fastdiv_values(p.KW, p.KWmp, p.KWL);
1335- init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
1343+ // Compute magic values to divide by OW, OW*OH
13361344 init_fastdiv_values(p.OW, p.OWmp, p.OWL);
13371345 init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
1338- init_fastdiv_values(p.s0, p.s0mp, p.s0L);
1339- init_fastdiv_values(p.s1, p.s1mp, p.s1L);
13401346}
13411347
13421348struct vk_op_conv2d_dw_push_constants {
@@ -3874,22 +3880,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
38743880 switch (s) {
38753881 default:
38763882 case CONV_SHAPE_128x128:
3877- conv2d_BS_K = 128 ;
3878- conv2d_BS_NPQ = 128 ;
3883+ conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_128x128][0] ;
3884+ conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_128x128][1] ;
38793885 conv2d_BS_CRS = 16;
38803886 if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
38813887 conv2d_UNROLL = false;
38823888 }
38833889 break;
38843890 case CONV_SHAPE_64x32:
3885- conv2d_BS_K = 64 ;
3886- conv2d_BS_NPQ = 32 ;
3891+ conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_64x32][0] ;
3892+ conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_64x32][1] ;
38873893 conv2d_BS_CRS = 32;
38883894 conv2d_TS_K = 4;
38893895 break;
38903896 case CONV_SHAPE_32x256:
3891- conv2d_BS_K = 32 ;
3892- conv2d_BS_NPQ = 256 ;
3897+ conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_32x256][0] ;
3898+ conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_32x256][1] ;
38933899 conv2d_BS_CRS = 16;
38943900 break;
38953901 }
@@ -3923,10 +3929,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
39233929 std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
39243930
// Creates one pipeline for every conv2d state currently registered in the
// per-shape map device->pipeline_<name><type_suffix>[s]. The shared
// spec_constants are copied and extended with the state's
// s0, s1, p0, p1, d0, d1, KW, KH (in that order) before pipeline creation.
// The loop variable is a non-const reference because the map's pipeline
// handle is passed on to ggml_vk_create_pipeline.
#define CREATE_CONV(name, type_suffix, spv_suffix) \
    for (auto &c : device->pipeline_##name##type_suffix[s]) { \
        const vk_conv2d_pipeline_state &state = c.first; \
        std::vector<uint32_t> spec_constants_cpy = spec_constants; \
        spec_constants_cpy.insert(spec_constants_cpy.end(), { \
            state.s0, state.s1, state.p0, state.p1, \
            state.d0, state.d1, state.KW, state.KH }); \
        ggml_vk_create_pipeline( \
            device, c.second, #name #type_suffix, \
            name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
            sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \
    }
39303948#define CREATE_CONVS(spv_suffix) \
39313949 CREATE_CONV(conv2d, _f32, spv_suffix) \
39323950 CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
@@ -8566,7 +8584,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
85668584
85678585 uint32_t tiles[CONV_SHAPE_COUNT];
85688586 for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
8569- tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32 [i]->wg_denoms [0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32 [i]->wg_denoms [1]);
8587+ tiles[i] = CEIL_DIV(elements[0], conv_shapes_wg_denoms [i][0]) * CEIL_DIV(elements[1], conv_shapes_wg_denoms [i][1]);
85708588 }
85718589
85728590 // We can't query number of shader cores on Intel, use 32 as a placeholder
@@ -8581,19 +8599,45 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
85818599 shape = CONV_SHAPE_64x32;
85828600 }
85838601
8602+ uint32_t KW = static_cast<uint32_t>(src0->ne[0]);
8603+ uint32_t KH = static_cast<uint32_t>(src0->ne[1]);
8604+ uint32_t s0 = static_cast<uint32_t>(dst->op_params[0]);
8605+ uint32_t s1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[1]) : static_cast<uint32_t>(dst->op_params[0]);
8606+ uint32_t p0 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[2]) : 0;
8607+ uint32_t p1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[3]) : 0;
8608+ uint32_t d0 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[4]) : 1;
8609+ uint32_t d1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[5]) : 1;
8610+
8611+ vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);
8612+
8613+ std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
85848614 if (op == GGML_OP_CONV_2D) {
85858615 if (src0->type == GGML_TYPE_F32) {
8586- return ctx->device->pipeline_conv2d_f32[shape];
8616+ pipelines = & ctx->device->pipeline_conv2d_f32[shape];
85878617 } else if (src0->type == GGML_TYPE_F16) {
8588- return ctx->device->pipeline_conv2d_f16_f32[shape];
8618+ pipelines = & ctx->device->pipeline_conv2d_f16_f32[shape];
85898619 }
85908620 } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
85918621 if (src0->type == GGML_TYPE_F32) {
8592- return ctx->device->pipeline_conv_transpose_2d_f32[shape];
8622+ pipelines = & ctx->device->pipeline_conv_transpose_2d_f32[shape];
85938623 } else if (src0->type == GGML_TYPE_F16) {
8594- return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
8624+ pipelines = &ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
8625+ }
8626+ }
8627+
8628+ vk_pipeline pipeline = nullptr;
8629+
8630+ {
8631+ std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
8632+ auto it = pipelines->find(conv2d_pipeline_state);
8633+ if (it != pipelines->end()) {
8634+ pipeline = it->second;
8635+ } else {
8636+ (*pipelines)[conv2d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
85958637 }
85968638 }
8639+
8640+ return pipeline;
85978641 }
85988642 return nullptr;
85998643 case GGML_OP_CONV_2D_DW:
0 commit comments