@@ -1520,22 +1520,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
15201520 !ggml_is_transposed (op->src [1 ]) &&
15211521 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
15221522 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1523- props_dev->has_simdgroup_mm &&
1524- op->src [1 ]->type == GGML_TYPE_F32 &&
1525- ne00 % 32 == 0 && ne00 >= 64 &&
1523+ props_dev->has_simdgroup_mm && ne00 >= 64 &&
15261524 (ne11 > ne11_mm_min || (ggml_is_quantized (op->src [0 ]->type ) && ne12 > 1 ))) {
15271525 // printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
15281526
15291527 // some Metal matrix data types require aligned pointers
15301528 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1531- switch (op->src [0 ]->type ) {
1532- case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1533- case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1534- case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1535- default : break ;
1536- }
1529+ // switch (op->src[0]->type) {
1530+ // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1531+ // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1532+ // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1533+ // default: break;
1534+ // }
15371535
1538- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm (lib, op-> src [ 0 ]-> type , op-> src [ 1 ]-> type );
1536+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm (lib, op);
15391537
15401538 ggml_metal_kargs_mul_mm args = {
15411539 /* .ne00 =*/ ne00,
@@ -1655,8 +1653,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
16551653 GGML_ASSERT (!ggml_is_transposed (op->src [0 ]));
16561654 GGML_ASSERT (!ggml_is_transposed (op->src [1 ]));
16571655
1658- GGML_ASSERT (op->src [1 ]->type == GGML_TYPE_F32);
1659-
16601656 GGML_ASSERT (ne03 == 1 );
16611657 GGML_ASSERT (ne13 == 1 );
16621658
@@ -1674,19 +1670,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
16741670 // ne21 = n_rows (batch size)
16751671 const int ne21_mm_id_min = 32 ;
16761672
1677- if (props_dev->has_simdgroup_mm &&
1678- ne00 % 32 == 0 && ne00 >= 64 &&
1679- (ne21 >= ne21_mm_id_min)) {
1680- GGML_ASSERT (ne00 % 4 == 0 );
1681-
1673+ if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
16821674 // some Metal matrix data types require aligned pointers
16831675 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1684- switch (op->src [0 ]->type ) {
1685- case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1686- case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1687- case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1688- default : break ;
1689- }
1676+ // switch (op->src[0]->type) {
1677+ // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1678+ // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1679+ // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1680+ // default: break;
1681+ // }
16901682
16911683 // extra buffers for intermediate id mapping
16921684 ggml_metal_buffer_id bid_tpe = bid_dst;
@@ -1730,7 +1722,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
17301722 ggml_metal_op_concurrency_reset (ctx);
17311723
17321724 {
1733- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id (lib, op-> src [ 0 ]-> type , GGML_TYPE_F16 );
1725+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id (lib, op);
17341726
17351727 ggml_metal_kargs_mul_mm_id args = {
17361728 /* .ne00 =*/ ne00,
0 commit comments