Skip to content

Commit 8bd52b6

Browse files
committed
restrict to amd
1 parent 70774e1 commit 8bd52b6

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

ggml/src/ggml-cuda/mmv.cu

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -456,8 +456,11 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
456 456
return ne11 <= 4;
457 457
}
458 458
return ne11 <= 3;
459-
} else if (fp32_mma_hardware_available(cc)) {
460-
return ne11 <= 3;
459+
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
460+
if (fp32_mma_hardware_available(cc)) {
461+
return ne11 <= 3;
462+
}
463+
return ne11 <= 8;
461 464
}
462 465
return ne11 <= 8;
463 466
case GGML_TYPE_F16:
@@ -470,11 +473,14 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
470 473
return src0_small && ne11 <= 3;
471 474
}
472 475
return ne11 <= 8;
473-
} else if (fp16_mma_hardware_available(cc)) {
474-
if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
475-
return ne11 <= 5;
476+
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
477+
if (fp16_mma_hardware_available(cc)) {
478+
if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
479+
return ne11 <= 5;
480+
}
481+
return ne11 <= 2;
476 482
}
477-
return ne11 <= 2;
483+
return ne11 <= 8;
478 484
}
479 485
return ne11 <= 8;
480 486
case GGML_TYPE_BF16:
@@ -487,8 +493,11 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
487 493
return src0_small && ne11 <= 3;
488 494
}
489 495
return ne11 <= 8;
490-
} else if (bf16_mma_hardware_available(cc)) {
491-
return ne11 <= 3;
496+
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
497+
if (bf16_mma_hardware_available(cc)) {
498+
return ne11 <= 3;
499+
}
500+
return ne11 <= 8;
492 501
}
493 502
return ne11 <= 8;
494 503
default:

0 commit comments

Comments
 (0)