Commit 165edb3
CUDA: use arch list for feature availability check
1 parent 19d3c82 commit 165edb3

3 files changed: +64 −8 lines

ggml/src/ggml-cuda/common.cuh

Lines changed: 62 additions & 6 deletions
@@ -71,6 +71,62 @@
 #define GGML_CUDA_CC_QY1 210
 #define GGML_CUDA_CC_QY2 220
 
+#ifdef __CUDA_ARCH_LIST__
+constexpr bool ggml_cuda_has_arch_impl(int) {
+    return false;
+}
+
+template<class ... Archs>
+constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
+    return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
+}
+
+constexpr bool ggml_cuda_has_arch(const int arch) {
+    return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
+}
+
+static int ggml_cuda_highest_compiled_arch(const int arch) {
+    switch (arch) {
+        case 1200: if (ggml_cuda_has_arch(1200)) return 1200; [[fallthrough]];
+        case 1010: if (ggml_cuda_has_arch(1010)) return 1010; [[fallthrough]];
+        case 1000: if (ggml_cuda_has_arch(1000)) return 1000; [[fallthrough]];
+        case  900: if (ggml_cuda_has_arch( 900)) return  900; [[fallthrough]];
+        case  890: if (ggml_cuda_has_arch( 890)) return  890; [[fallthrough]];
+        case  870: if (ggml_cuda_has_arch( 870)) return  870; [[fallthrough]];
+        case  860: if (ggml_cuda_has_arch( 860)) return  860; [[fallthrough]];
+        case  800: if (ggml_cuda_has_arch( 800)) return  800; [[fallthrough]];
+        case  750: if (ggml_cuda_has_arch( 750)) return  750; [[fallthrough]];
+        case  720: if (ggml_cuda_has_arch( 720)) return  720; [[fallthrough]];
+        case  700: if (ggml_cuda_has_arch( 700)) return  700; [[fallthrough]];
+        case  620: if (ggml_cuda_has_arch( 620)) return  620; [[fallthrough]];
+        case  610: if (ggml_cuda_has_arch( 610)) return  610; [[fallthrough]];
+        case  600: if (ggml_cuda_has_arch( 600)) return  600; [[fallthrough]];
+        case  530: if (ggml_cuda_has_arch( 530)) return  530; [[fallthrough]];
+        case  520: if (ggml_cuda_has_arch( 520)) return  520; [[fallthrough]];
+        case  500: if (ggml_cuda_has_arch( 500)) return  500; [[fallthrough]];
+        case  370: if (ggml_cuda_has_arch( 370)) return  370; [[fallthrough]];
+        case  350: if (ggml_cuda_has_arch( 350)) return  350; [[fallthrough]];
+        case  320: if (ggml_cuda_has_arch( 320)) return  320; [[fallthrough]];
+        case  300: if (ggml_cuda_has_arch( 300)) return  300; [[fallthrough]];
+        case  210: if (ggml_cuda_has_arch( 210)) return  210; [[fallthrough]];
+        case  200: if (ggml_cuda_has_arch( 200)) return  200; [[fallthrough]];
+        case  130: if (ggml_cuda_has_arch( 130)) return  130; [[fallthrough]];
+        case  120: if (ggml_cuda_has_arch( 120)) return  120; [[fallthrough]];
+        case  110: if (ggml_cuda_has_arch( 110)) return  110; [[fallthrough]];
+        case  100: if (ggml_cuda_has_arch( 100)) return  100;
+            GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
+
+        default: GGML_ABORT("unknown CUDA arch: %d", arch);
+    }
+}
+#else
+static int ggml_cuda_highest_compiled_arch(const int arch) {
+    return arch;
+}
+#endif // __CUDA_ARCH_LIST__
+
+// ---------------------------------------------------------------------------------------------------------
+
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 #if defined(_MSC_VER)
@@ -162,18 +218,18 @@ typedef float2 dfloat2;
 #define FLASH_ATTN_AVAILABLE
 #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 
-static constexpr bool fast_fp16_available(const int cc) {
-    return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
+static bool fast_fp16_available(const int cc) {
+    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL && cc != 610;
 }
 
 // Any FP16 tensor cores are available.
-static constexpr bool fp16_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
+static bool fp16_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
-static constexpr bool new_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
+static bool new_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
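For intuition, here is a minimal, self-contained sketch (not part of the commit) of the same compile-time membership test. The names has_arch_impl, has_arch, and EXAMPLE_ARCH_LIST are hypothetical; EXAMPLE_ARCH_LIST stands in for nvcc's __CUDA_ARCH_LIST__, which, when defined, expands to a comma-separated list of the architectures the translation unit is being compiled for, and the value 610,750,860 is chosen purely for illustration.

// Standalone sketch; builds as ordinary C++ (e.g. g++ -std=c++17 sketch.cpp).
#include <cstdio>

#define EXAMPLE_ARCH_LIST 610,750,860 // hypothetical stand-in for __CUDA_ARCH_LIST__

constexpr bool has_arch_impl(int) {
    return false; // base case: the arch was not found in the pack
}

template<class ... Archs>
constexpr bool has_arch_impl(const int arch, const int first, Archs... rest) {
    return arch == first || has_arch_impl(arch, rest...); // linear search over the pack
}

constexpr bool has_arch(const int arch) {
    return has_arch_impl(arch, EXAMPLE_ARCH_LIST); // the macro expands into the argument pack
}

static_assert( has_arch(750), "750 is in the example list");
static_assert(!has_arch(520), "520 is not in the example list");

int main() {
    printf("has_arch(860) = %d\n", has_arch(860)); // prints 1 for the example list
    return 0;
}

In the real header, ggml_cuda_highest_compiled_arch then walks down the table of known compute capabilities and returns the highest one that is both present in the arch list and not greater than the device's capability.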

ggml/src/ggml-cuda/mmq.cu

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return true;
     }
 
-    if (cc < GGML_CUDA_CC_DP4A) {
+    if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
         return false;
     }
 
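As a hedged illustration of what this clamp changes, the sketch below hard-codes a hypothetical build whose arch list contains only 520 and 750 and evaluates the DP4A gate for two devices, using 610 as the threshold (the compute capability that introduced the __dp4a byte-wise dot-product instruction). The names CC_DP4A and highest_compiled_arch_example are local stand-ins, not the real ggml functions.

// Sketch only: a hypothetical clamp for a build compiled for just {520, 750}.
#include <cstdio>

constexpr int CC_DP4A = 610; // __dp4a needs compute capability 6.1

constexpr int highest_compiled_arch_example(const int cc) {
    return cc >= 750 ? 750   // the 750 kernels cover every newer device
         : cc >= 520 ? 520   // older devices fall back to the 520 build
         : 0;                // nothing compiled for anything older
}

int main() {
    // A device reporting compute 8.9 runs the 750 kernels, so DP4A-based MMQ stays enabled:
    printf("cc 890 -> arch %d, DP4A gate %d\n",
           highest_compiled_arch_example(890),
           highest_compiled_arch_example(890) >= CC_DP4A); // 750, 1
    // A compute 6.1 device only has the 520 build available, so the gate now fails,
    // whereas the old raw-cc check (cc < GGML_CUDA_CC_DP4A) would have let MMQ through:
    printf("cc 610 -> arch %d, DP4A gate %d\n",
           highest_compiled_arch_example(610),
           highest_compiled_arch_example(610) >= CC_DP4A); // 520, 0
    return 0;
}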

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ struct tile_x_sizes {
     int sc;
 };
 
-static constexpr int get_mmq_x_max_host(const int cc) {
+static int get_mmq_x_max_host(const int cc) {
     return new_mma_available(cc) ? 128 :
 #ifdef GGML_CUDA_FORCE_MMQ
         cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
