@@ -1232,8 +1232,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
12321232    std::cerr << " ggml_vulkan: Compiling shaders" 
12331233
12341234    //  mulmat
1235-     std::vector<uint32_t > l_warptile, m_warptile, s_warptile, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
1236-     std::array<uint32_t , 3 > l_wg_denoms, m_wg_denoms, s_wg_denoms;
1235+     std::vector<uint32_t > l_warptile, m_warptile, s_warptile,
1236+                           l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
1237+     std::array<uint32_t , 3 > l_wg_denoms, m_wg_denoms, s_wg_denoms,
1238+                             l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms;
12371239    uint32_t  l_align, m_align, s_align;
12381240
12391241    l_warptile = { 128 , 128 , 128 , 16 , device->subgroup_size  * 2 , 64 , 2 , 4 , 4 , device->subgroup_size  };
@@ -1244,14 +1246,48 @@ static void ggml_vk_load_shaders(vk_device& device) {
12441246    m_warptile_mmq = { 128 ,  64 ,  64 , 32 , device->subgroup_size , 32 , 2 , 4 , 2 , device->subgroup_size  };
12451247    s_warptile_mmq = { std::max (device->subgroup_size , 16u ),  32 ,  32 , 32 , 32 , 32 , 2 , 2 , 2 , device->subgroup_size  };
12461248
1247-     l_wg_denoms = {128 , 128 , 1  };
1248-     m_wg_denoms = { 64 ,  64 , 1  };
1249-     s_wg_denoms = { 32 ,  32 , 1  };
1249+     l_mmq_wg_denoms =  l_wg_denoms = {128 , 128 , 1  };
1250+     m_mmq_wg_denoms =  m_wg_denoms = { 64 ,  64 , 1  };
1251+     s_mmq_wg_denoms =  s_wg_denoms = { 32 ,  32 , 1  };
12501252
12511253    l_align = 128 ;
12521254    m_align =  64 ;
12531255    s_align =  32 ;
12541256
1257+     //  Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
1258+     //  and tile sizes, this should handle 16KB, 32KB, and 48KB+.
1259+     //  This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
1260+     //  But the numbers happen to work out for 32KB shared memory size that when using the medium
1261+     //  size there's enough room for everything, and we assert for this.
1262+     uint32_t  shmem_needed = (l_warptile[1 ] + l_warptile[2 ]) * (l_warptile[3 ] + 1 ) * sizeof (float );
1263+     if  (shmem_needed > device->properties .limits .maxComputeSharedMemorySize ) {
1264+         l_warptile = m_warptile;
1265+         l_wg_denoms = m_wg_denoms;
1266+         shmem_needed = (l_warptile[1 ] + l_warptile[2 ]) * (l_warptile[3 ] + 1 ) * sizeof (float );
1267+         GGML_ASSERT (shmem_needed <= device->properties .limits .maxComputeSharedMemorySize );
1268+     }
1269+     if  (device->properties .limits .maxComputeSharedMemorySize  >= 32768 ) {
1270+         //  assert mul_mat_mat_id shaders will fit.
1271+         GGML_ASSERT (shmem_needed + 3072 *4  <= device->properties .limits .maxComputeSharedMemorySize );
1272+     }
1273+ 
1274+     shmem_needed = (l_warptile_mmq[1 ] + l_warptile_mmq[2 ]) * (l_warptile_mmq[3 ] + 1 ) * sizeof (float );
1275+     if  (shmem_needed > device->properties .limits .maxComputeSharedMemorySize ) {
1276+         if  (device->properties .limits .maxComputeSharedMemorySize  == 32768 ) {
1277+             l_warptile_mmq = m_warptile_mmq;
1278+             l_mmq_wg_denoms = m_mmq_wg_denoms;
1279+         } else  {
1280+             l_warptile_mmq = s_warptile_mmq;
1281+             l_mmq_wg_denoms = s_mmq_wg_denoms;
1282+         }
1283+         shmem_needed = (l_warptile_mmq[1 ] + l_warptile_mmq[2 ]) * (l_warptile_mmq[3 ] + 1 ) * sizeof (float );
1284+         GGML_ASSERT (shmem_needed <= device->properties .limits .maxComputeSharedMemorySize );
1285+     }
1286+     if  (device->properties .limits .maxComputeSharedMemorySize  >= 32768 ) {
1287+         //  assert mul_mat_mat_id shaders will fit.
1288+         GGML_ASSERT (shmem_needed + 3072 *4  <= device->properties .limits .maxComputeSharedMemorySize );
1289+     }
1290+ 
12551291    device->pipeline_matmul_f32  = std::make_shared<vk_matmul_pipeline_struct>();
12561292    device->pipeline_matmul_f32_f16  = std::make_shared<vk_matmul_pipeline_struct>();
12571293
@@ -1299,35 +1335,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
12991335        CREATE_MM (pipeline_matmul_f16.f32acc , matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
13001336        CREATE_MM (pipeline_matmul_f16_f32.f32acc , matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
13011337
1302-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1303-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1304-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1305-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1306-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1307- 
1308-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1309-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1310-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1311-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1312-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1313-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1314- 
1315-         CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1316-         CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1317-         CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1318- 
1319-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1320-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1321-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1322-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1323-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1324- 
1325-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1326-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1327-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1328-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1329-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1330-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1338+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1339+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1340+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1341+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1342+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1343+ 
1344+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1345+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1346+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1347+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1348+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1349+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1350+ 
1351+         //  If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1352+         if  (device->properties .limits .maxComputeSharedMemorySize  >= 32768 ) {
1353+             CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1354+             CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1355+             CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1356+ 
1357+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1358+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1359+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1360+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1361+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1362+ 
1363+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1364+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1365+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1366+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1367+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1368+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1369+         }
13311370#undef  CREATE_MM
13321371    } else  {
13331372        //  Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -1344,35 +1383,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
13441383        CREATE_MM (pipeline_matmul_f16.f32acc , matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
13451384        CREATE_MM (pipeline_matmul_f16_f32.f32acc , matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
13461385
1347-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1348-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1349-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1350-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1351-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1352- 
1353-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1354-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1355-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1356-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1357-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1358-         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1359- 
1360-         CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1361-         CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1362-         CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1363- 
1364-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1365-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1366-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1367-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1368-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1369- 
1370-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1371-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1372-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1373-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1374-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1375-         CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1386+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1387+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1388+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1389+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1390+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1391+ 
1392+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1393+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1394+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1395+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1396+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1397+         CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1398+ 
1399+         //  If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1400+         if  (device->properties .limits .maxComputeSharedMemorySize  >= 32768 ) {
1401+             CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1402+             CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1403+             CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1404+ 
1405+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1406+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1407+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1408+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1409+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1410+ 
1411+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1412+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1413+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1414+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1415+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1416+             CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1417+         }
13761418#undef  CREATE_MM
13771419    }
13781420
@@ -6541,6 +6583,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
65416583        case  GGML_OP_MUL_MAT:
65426584        case  GGML_OP_MUL_MAT_ID:
65436585            {
6586+                 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6587+                 if  (op->op  == GGML_OP_MUL_MAT_ID &&
6588+                     ggml_vk_get_device (ctx->device )->properties .limits .maxComputeSharedMemorySize  < 32768 ) {
6589+                     //  If there's not enough shared memory for row_ids and the result tile, fallback to CPU
6590+                     return  false ;
6591+                 }
65446592                switch  (op->src [0 ]->type ) {
65456593                    case  GGML_TYPE_F32:
65466594                    case  GGML_TYPE_F16:
0 commit comments