@@ -1232,8 +1232,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
12321232 std::cerr << " ggml_vulkan: Compiling shaders" ;
12331233
12341234 // mulmat
1235- std::vector<uint32_t > l_warptile, m_warptile, s_warptile, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
1236- std::array<uint32_t , 3 > l_wg_denoms, m_wg_denoms, s_wg_denoms;
1235+ std::vector<uint32_t > l_warptile, m_warptile, s_warptile,
1236+ l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
1237+ std::array<uint32_t , 3 > l_wg_denoms, m_wg_denoms, s_wg_denoms,
1238+ l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms;
12371239 uint32_t l_align, m_align, s_align;
12381240
12391241 l_warptile = { 128 , 128 , 128 , 16 , device->subgroup_size * 2 , 64 , 2 , 4 , 4 , device->subgroup_size };
@@ -1244,14 +1246,48 @@ static void ggml_vk_load_shaders(vk_device& device) {
12441246 m_warptile_mmq = { 128 , 64 , 64 , 32 , device->subgroup_size , 32 , 2 , 4 , 2 , device->subgroup_size };
12451247 s_warptile_mmq = { std::max (device->subgroup_size , 16u ), 32 , 32 , 32 , 32 , 32 , 2 , 2 , 2 , device->subgroup_size };
12461248
1247- l_wg_denoms = {128 , 128 , 1 };
1248- m_wg_denoms = { 64 , 64 , 1 };
1249- s_wg_denoms = { 32 , 32 , 1 };
1249+ l_mmq_wg_denoms = l_wg_denoms = {128 , 128 , 1 };
1250+ m_mmq_wg_denoms = m_wg_denoms = { 64 , 64 , 1 };
1251+ s_mmq_wg_denoms = s_wg_denoms = { 32 , 32 , 1 };
12501252
12511253 l_align = 128 ;
12521254 m_align = 64 ;
12531255 s_align = 32 ;
12541256
1257+ // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
1258+ // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
1259+ // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
1260+ // But the numbers happen to work out for 32KB shared memory size that when using the medium
1261+ // size there's enough room for everything, and we assert for this.
1262+ uint32_t shmem_needed = (l_warptile[1 ] + l_warptile[2 ]) * (l_warptile[3 ] + 1 ) * sizeof (float );
1263+ if (shmem_needed > device->properties .limits .maxComputeSharedMemorySize ) {
1264+ l_warptile = m_warptile;
1265+ l_wg_denoms = m_wg_denoms;
1266+ shmem_needed = (l_warptile[1 ] + l_warptile[2 ]) * (l_warptile[3 ] + 1 ) * sizeof (float );
1267+ GGML_ASSERT (shmem_needed <= device->properties .limits .maxComputeSharedMemorySize );
1268+ }
1269+ if (device->properties .limits .maxComputeSharedMemorySize >= 32768 ) {
1270+ // assert mul_mat_mat_id shaders will fit.
1271+ GGML_ASSERT (shmem_needed + 3072 *4 <= device->properties .limits .maxComputeSharedMemorySize );
1272+ }
1273+
1274+ shmem_needed = (l_warptile_mmq[1 ] + l_warptile_mmq[2 ]) * (l_warptile_mmq[3 ] + 1 ) * sizeof (float );
1275+ if (shmem_needed > device->properties .limits .maxComputeSharedMemorySize ) {
1276+ if (device->properties .limits .maxComputeSharedMemorySize == 32768 ) {
1277+ l_warptile_mmq = m_warptile_mmq;
1278+ l_mmq_wg_denoms = m_mmq_wg_denoms;
1279+ } else {
1280+ l_warptile_mmq = s_warptile_mmq;
1281+ l_mmq_wg_denoms = s_mmq_wg_denoms;
1282+ }
1283+ shmem_needed = (l_warptile_mmq[1 ] + l_warptile_mmq[2 ]) * (l_warptile_mmq[3 ] + 1 ) * sizeof (float );
1284+ GGML_ASSERT (shmem_needed <= device->properties .limits .maxComputeSharedMemorySize );
1285+ }
1286+ if (device->properties .limits .maxComputeSharedMemorySize >= 32768 ) {
1287+ // assert mul_mat_mat_id shaders will fit.
1288+ GGML_ASSERT (shmem_needed + 3072 *4 <= device->properties .limits .maxComputeSharedMemorySize );
1289+ }
1290+
12551291 device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
12561292 device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
12571293
@@ -1299,35 +1335,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
12991335 CREATE_MM (pipeline_matmul_f16.f32acc , matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
13001336 CREATE_MM (pipeline_matmul_f16_f32.f32acc , matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
13011337
1302- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1303- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1304- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1305- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1306- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1307-
1308- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1309- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1310- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1311- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1312- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1313- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1314-
1315- CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1316- CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1317- CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1318-
1319- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1320- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1321- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1322- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1323- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1324-
1325- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1326- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1327- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1328- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1329- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1330- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1338+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1339+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1340+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1341+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1342+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1343+
1344+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1345+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1346+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1347+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1348+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1349+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1350+
1351+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1352+ if (device->properties .limits .maxComputeSharedMemorySize >= 32768 ) {
1353+ CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1354+ CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1355+ CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1356+
1357+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1358+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1359+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1360+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1361+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1362+
1363+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1364+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1365+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1366+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1367+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1368+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1369+ }
13311370#undef CREATE_MM
13321371 } else {
13331372 // Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -1344,35 +1383,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
13441383 CREATE_MM (pipeline_matmul_f16.f32acc , matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
13451384 CREATE_MM (pipeline_matmul_f16_f32.f32acc , matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
13461385
1347- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1348- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1349- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1350- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1351- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1352-
1353- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1354- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1355- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1356- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1357- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1358- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1359-
1360- CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1361- CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1362- CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1363-
1364- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1365- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1366- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1367- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1368- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1369-
1370- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1371- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1372- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1373- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1374- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1375- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1386+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1387+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1388+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1389+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1390+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1391+
1392+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1393+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1394+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1395+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1396+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1397+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1398+
1399+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1400+ if (device->properties .limits .maxComputeSharedMemorySize >= 32768 ) {
1401+ CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1402+ CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1403+ CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1404+
1405+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1406+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1407+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1408+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1409+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1410+
1411+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1412+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1413+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1414+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1415+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1416+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1417+ }
13761418#undef CREATE_MM
13771419 }
13781420
@@ -6541,6 +6583,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
65416583 case GGML_OP_MUL_MAT:
65426584 case GGML_OP_MUL_MAT_ID:
65436585 {
6586+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6587+ if (op->op == GGML_OP_MUL_MAT_ID &&
6588+ ggml_vk_get_device (ctx->device )->properties .limits .maxComputeSharedMemorySize < 32768 ) {
6589+ // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
6590+ return false ;
6591+ }
65446592 switch (op->src [0 ]->type ) {
65456593 case GGML_TYPE_F32:
65466594 case GGML_TYPE_F16:
0 commit comments