Commit d55e584

CUDA: use arch list for feature availability check
1 parent 19d3c82 commit d55e584

File tree: 6 files changed, +80 -20 lines

  ggml/src/ggml-common.h
  ggml/src/ggml-cuda/common.cuh
  ggml/src/ggml-cuda/convert.cu
  ggml/src/ggml-cuda/ggml-cuda.cu
  ggml/src/ggml-cuda/mmq.cu
  ggml/src/ggml-cuda/mmq.cuh

ggml/src/ggml-common.h

Lines changed: 0 additions & 2 deletions

@@ -473,7 +473,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 GGML_TABLE_END()
 
-//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
 GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
     0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,

@@ -508,7 +507,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
     0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
 GGML_TABLE_END()
-//#endif
 
 
 GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)

ggml/src/ggml-cuda/common.cuh

Lines changed: 66 additions & 6 deletions

@@ -71,6 +71,62 @@
 #define GGML_CUDA_CC_QY1 210
 #define GGML_CUDA_CC_QY2 220
 
+#ifdef __CUDA_ARCH_LIST__
+constexpr bool ggml_cuda_has_arch_impl(int) {
+    return false;
+}
+
+template<class ... Archs>
+constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
+    return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
+}
+
+constexpr bool ggml_cuda_has_arch(const int arch) {
+    return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
+}
+
+static int ggml_cuda_highest_compiled_arch(const int arch) {
+    switch (arch) {
+        case 1200: if (ggml_cuda_has_arch(1200)) return 1200; [[fallthrough]];
+        case 1010: if (ggml_cuda_has_arch(1010)) return 1010; [[fallthrough]];
+        case 1000: if (ggml_cuda_has_arch(1000)) return 1000; [[fallthrough]];
+        case  900: if (ggml_cuda_has_arch( 900)) return  900; [[fallthrough]];
+        case  890: if (ggml_cuda_has_arch( 890)) return  890; [[fallthrough]];
+        case  870: if (ggml_cuda_has_arch( 870)) return  870; [[fallthrough]];
+        case  860: if (ggml_cuda_has_arch( 860)) return  860; [[fallthrough]];
+        case  800: if (ggml_cuda_has_arch( 800)) return  800; [[fallthrough]];
+        case  750: if (ggml_cuda_has_arch( 750)) return  750; [[fallthrough]];
+        case  720: if (ggml_cuda_has_arch( 720)) return  720; [[fallthrough]];
+        case  700: if (ggml_cuda_has_arch( 700)) return  700; [[fallthrough]];
+        case  620: if (ggml_cuda_has_arch( 620)) return  620; [[fallthrough]];
+        case  610: if (ggml_cuda_has_arch( 610)) return  610; [[fallthrough]];
+        case  600: if (ggml_cuda_has_arch( 600)) return  600; [[fallthrough]];
+        case  530: if (ggml_cuda_has_arch( 530)) return  530; [[fallthrough]];
+        case  520: if (ggml_cuda_has_arch( 520)) return  520; [[fallthrough]];
+        case  500: if (ggml_cuda_has_arch( 500)) return  500; [[fallthrough]];
+        case  370: if (ggml_cuda_has_arch( 370)) return  370; [[fallthrough]];
+        case  350: if (ggml_cuda_has_arch( 350)) return  350; [[fallthrough]];
+        case  320: if (ggml_cuda_has_arch( 320)) return  320; [[fallthrough]];
+        case  300: if (ggml_cuda_has_arch( 300)) return  300; [[fallthrough]];
+        case  210: if (ggml_cuda_has_arch( 210)) return  210; [[fallthrough]];
+        case  200: if (ggml_cuda_has_arch( 200)) return  200; [[fallthrough]];
+        case  130: if (ggml_cuda_has_arch( 130)) return  130; [[fallthrough]];
+        case  120: if (ggml_cuda_has_arch( 120)) return  120; [[fallthrough]];
+        case  110: if (ggml_cuda_has_arch( 110)) return  110; [[fallthrough]];
+        case  100: if (ggml_cuda_has_arch( 100)) return  100;
+            GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
+
+        default: GGML_ABORT("unknown CUDA arch: %d", arch);
+    }
+}
+#else
+static int ggml_cuda_highest_compiled_arch(const int arch) {
+    return arch;
+}
+#endif // __CUDA_ARCH_LIST__
+
+// ---------------------------------------------------------------------------------------------------------
+
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 #if defined(_MSC_VER)

@@ -162,18 +218,22 @@ typedef float2 dfloat2;
 #define FLASH_ATTN_AVAILABLE
 #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 
-static constexpr bool fast_fp16_available(const int cc) {
-    return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
+static bool fp16_available(const int cc) {
+    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
+}
+
+static bool fast_fp16_available(const int cc) {
+    return fp16_available(cc) && cc != 610;
 }
 
 // Any FP16 tensor cores are available.
-static constexpr bool fp16_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
+static bool fp16_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
-static constexpr bool new_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
+static bool new_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
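For context: the helpers above rely on nvcc's `__CUDA_ARCH_LIST__` macro, a comma-separated list of the architectures the translation unit is compiled for (e.g. 610,860), so `ggml_cuda_has_arch(x)` expands to a chain of equality tests against that list. Below is a minimal stand-alone sketch of how the recursion resolves; the `has_arch_impl` name and the example arch list {610, 860} are illustrative assumptions, not part of the commit.

// Sketch only: mirrors ggml_cuda_has_arch_impl, assuming a build configured
// for compute capabilities 6.1 and 8.6.
constexpr bool has_arch_impl(int) {
    return false;                                            // base case: arch not in the list
}

template <class... Archs>
constexpr bool has_arch_impl(const int arch, const int first, Archs... rest) {
    return arch == first || has_arch_impl(arch, rest...);    // unrolls into equality tests
}

// With __CUDA_ARCH_LIST__ == 610,860 a check for arch x expands to
// has_arch_impl(x, 610, 860), evaluated entirely at compile time:
static_assert( has_arch_impl(610, 610, 860), "6.1 code was compiled");
static_assert(!has_arch_impl(750, 610, 860), "no 7.5 code in this build");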

ggml/src/ggml-cuda/convert.cu

Lines changed: 1 addition & 1 deletion

@@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
-            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
+            if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                 return dequantize_block_q8_0_f16_cuda;
             }
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 2 deletions

@@ -3205,8 +3205,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
-            const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-            return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc);
         }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 3 deletions

@@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
 
     int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
+    const int cc = ggml_cuda_info().devices[id].cc;
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // nrows_dst == nrows of the matrix that the kernel writes into

@@ -27,7 +27,8 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
    // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
+        cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
     const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
 
     switch (src0->type) {

@@ -136,7 +137,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return true;
     }
 
-    if (cc < GGML_CUDA_CC_DP4A) {
+    if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
         return false;
     }
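The practical effect of routing gates such as `use_stream_k` through `ggml_cuda_highest_compiled_arch` is that a feature is only enabled when code for a sufficiently new architecture was actually compiled and does not exceed the device, rather than whenever the physical GPU reports a high compute capability. A simplified host-only sketch follows; the helper name, the example arch list, and returning 0 instead of aborting when nothing matches are assumptions made for a self-contained example.

#include <cassert>
#include <vector>

// Hypothetical stand-in for ggml_cuda_highest_compiled_arch: pick the highest
// compiled architecture that does not exceed the device's compute capability.
static int highest_compiled_arch(int cc, const std::vector<int> & compiled) {
    int best = 0;
    for (int a : compiled) {
        if (a <= cc && a > best) {
            best = a;
        }
    }
    return best;
}

int main() {
    const std::vector<int> archs = {610, 860};          // e.g. a build for 6.1 and 8.6 only
    assert(highest_compiled_arch(750, archs) == 610);   // Turing device falls back to the 6.1 code path
    assert(highest_compiled_arch(890, archs) == 860);   // Ada device runs the 8.6 (Ampere) code path
    // 610 < 700 (Volta), so stream-k stays disabled on the Turing device in this build,
    // even though the hardware itself is newer than Volta.
    const bool use_stream_k = highest_compiled_arch(750, archs) >= 700;
    assert(!use_stream_k);
    return 0;
}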

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 8 additions & 6 deletions

@@ -86,12 +86,13 @@ struct tile_x_sizes {
     int sc;
 };
 
-static constexpr int get_mmq_x_max_host(const int cc) {
+static int get_mmq_x_max_host(const int cc) {
     return new_mma_available(cc) ? 128 :
+        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
 #ifdef GGML_CUDA_FORCE_MMQ
-        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
+            128 : 64;
 #else
-        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
+            MMQ_DP4A_MAX_BATCH_SIZE : 64;
 #endif // GGML_CUDA_FORCE_MMQ
 }
 
@@ -119,8 +120,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #endif // NEW_MMA_AVAILABLE
 }
 
-static constexpr int get_mmq_y_host(const int cc) {
-    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
+static int get_mmq_y_host(const int cc) {
+    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
+        (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
 }
 
 static constexpr __device__ int get_mmq_y_device() {

@@ -2828,7 +2830,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
 
     int mmq_x_best = 0;
     int nparts_best = INT_MAX;
