Skip to content

Commit aa35feb

Browse files
committed
Feat: Remove warnings and the deprecated __AMDGCN_WAVEFRONT_SIZE macro
1 parent 0215a80 commit aa35feb

File tree

3 files changed

+91
-82
lines changed

3 files changed

+91
-82
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,12 @@ static bool cp_async_available(const int cc) {
271271
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
272272
}
273273

274-
static constexpr __host__ __device__ int ggml_cuda_get_physical_warp_size() {
275-
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
276-
return __AMDGCN_WAVEFRONT_SIZE;
274+
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
275+
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
276+
return 64;
277277
#else
278278
return 32;
279-
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
279+
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
280280
}
281281

282282
[[noreturn]]

ggml/src/ggml-cuda/mma.cuh

Lines changed: 53 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -64,71 +64,70 @@ namespace ggml_cuda_mma {
6464

6565
template <int I_, int J_, typename T>
6666
struct tile {
67-
static constexpr int warp_size = ggml_cuda_get_physical_warp_size();
6867
static constexpr int I = I_;
6968
static constexpr int J = J_;
70-
static constexpr int ne = I * J / warp_size;
69+
70+
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
71+
static constexpr int ne = I * J / 64;
72+
T x[ne] = {0};
73+
74+
static __device__ __forceinline__ int get_i(const int l) {
75+
if constexpr (I == 16 && J == 8) {
76+
return threadIdx.x % 16;
77+
} else if constexpr (I == 32 && J == 4) {
78+
return threadIdx.x % 32;
79+
} else if constexpr (I == 16 && J == 16) {
80+
return 4 * (threadIdx.x / 16) + l;
81+
} else if constexpr (I == 32 && J == 32) {
82+
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
83+
} else {
84+
static_assert(I == -1 && J == -1, "template specialization not implemented");
85+
}
86+
}
87+
88+
static __device__ __forceinline__ int get_j(const int l) {
89+
if constexpr (I == 16 && J == 8) {
90+
return 2 * (threadIdx.x / 16) + l;
91+
} else if constexpr (I == 32 && J == 4) {
92+
return 2 * (threadIdx.x / 32) + l;
93+
} else if constexpr (I == 16 && J == 16) {
94+
return threadIdx.x % 16;
95+
} else if constexpr (I == 32 && J == 32) {
96+
return threadIdx.x % 32;
97+
} else {
98+
static_assert(I == -1 && J == -1, "template specialization not implemented");
99+
}
100+
}
101+
#else
102+
static constexpr int ne = I * J / 32;
71103
T x[ne] = {0};
72104

73105
static __device__ __forceinline__ int get_i(const int l) {
74-
if constexpr (warp_size == 32) {
75-
if constexpr (I == 8 && (J == 4 || J == 8)) {
76-
return threadIdx.x / 4;
77-
} else if constexpr (I == 16 && J == 8) {
78-
return (l / 2) * 8 + threadIdx.x / 4;
79-
} else if constexpr (I == 16 && J == 16) {
80-
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
81-
} else {
82-
static_assert(I == -1 && J == -1, "template specialization not implemented");
83-
}
84-
} else if constexpr (warp_size == 64) {
85-
if constexpr (I == 8 && (J == 4 || J == 8)) { // Remove this case
86-
return threadIdx.x / 4;
87-
} else if constexpr (I == 16 && J == 8) {
88-
return threadIdx.x % 16;
89-
} else if constexpr (I == 32 && J == 4) {
90-
return threadIdx.x % 32;
91-
} else if constexpr (I == 16 && J == 16) {
92-
return 4 * (threadIdx.x / 16) + l;
93-
} else if constexpr (I == 32 && J == 32) {
94-
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
95-
} else {
96-
static_assert(I == -1 && J == -1, "template specialization not implemented");
97-
}
106+
if constexpr (I == 8 && (J == 4 || J == 8)) {
107+
return threadIdx.x / 4;
108+
} else if constexpr (I == 16 && J == 8) {
109+
return (l / 2) * 8 + threadIdx.x / 4;
110+
} else if constexpr (I == 16 && J == 16) {
111+
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
112+
} else {
113+
static_assert(I == -1 && J == -1, "template specialization not implemented");
98114
}
99115
}
100116

101117
static __device__ __forceinline__ int get_j(const int l) {
102-
if constexpr (warp_size == 32) {
103-
if constexpr (I == 8 && J == 4) {
104-
return threadIdx.x % 4;
105-
} else if constexpr (I == 8 && J == 8) {
106-
return 4 * l + threadIdx.x % 4;
107-
} else if constexpr (I == 16 && J == 8) {
108-
return 2 * (threadIdx.x % 4) + l % 2;
109-
} else if constexpr (I == 16 && J == 16) {
110-
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
111-
} else {
112-
static_assert(I == -1 && J == -1, "template specialization not implemented");
113-
}
114-
} else if constexpr (warp_size == 64) {
115-
if constexpr (I == 8 && J == 4) { // Remove this case
116-
return threadIdx.x % 4;
117-
} else if constexpr (I == 8 && J == 8) { // Remove this case
118-
return 4 * l + threadIdx.x % 4;
119-
} else if constexpr (I == 16 && J == 8) {
120-
return 2 * (threadIdx.x / 16) + l;
121-
} else if constexpr (I == 32 && J == 4) {
122-
return 2 * (threadIdx.x / 32) + l;
123-
} else if constexpr (I == 16 && J == 16) {
124-
return threadIdx.x % 16;
125-
} else if constexpr (I == 32 && J == 32) {
126-
return threadIdx.x % 32;
127-
} else {
128-
static_assert(I == -1 && J == -1, "template specialization not implemented");
129-
}
118+
if constexpr (I == 8 && J == 4) {
119+
return threadIdx.x % 4;
120+
} else if constexpr (I == 8 && J == 8) {
121+
return 4 * l + threadIdx.x % 4;
122+
} else if constexpr (I == 16 && J == 8) {
123+
return 2 * (threadIdx.x % 4) + l % 2;
124+
} else if constexpr (I == 16 && J == 16) {
125+
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
126+
} else {
127+
static_assert(I == -1 && J == -1, "template specialization not implemented");
130128
}
131129
}
130+
#endif
132131
};
133132

134133
template <int I_, int J_>

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,8 @@ static int mmq_get_granularity_host(ggml_type type, const int mmq_x, const int c
253253
case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
254254
case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
255255
return mmq_x >= 192 ? 64 : 32;
256+
default:
257+
return 0;
256258
}
257259
} else if (new_mma_available(cc) && mmq_x >= 48) {
258260
return 16;
@@ -285,6 +287,8 @@ static constexpr __device__ int mmq_get_granularity_device(ggml_type type, const
285287
case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
286288
case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
287289
return mmq_x >= 192 ? 64 : 32;
290+
default:
291+
return 0;
288292
}
289293
}
290294
#elif defined(NEW_MMA_AVAILABLE)
@@ -323,6 +327,8 @@ static int get_mmq_nwarps_host(ggml_type type, const int cc) {
323327
case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
324328
case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
325329
return 4;
330+
default:
331+
return 0;
326332
}
327333
} else {
328334
return 8;
@@ -355,6 +361,8 @@ static constexpr __device__ int get_mmq_nwarps_device(ggml_type type) {
355361
case GGML_TYPE_IQ2_XS: // vec_dot_q8_0_16_q8_1_mma
356362
case GGML_TYPE_IQ2_S: // vec_dot_q8_0_16_q8_1_mma
357363
return 4;
364+
default:
365+
return 0;
358366
}
359367
}
360368
#else
@@ -3123,16 +3131,16 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
31233131

31243132
// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
31253133

3126-
template <ggml_type type, int mmq_x, int warp_size, bool need_check>
3134+
template <ggml_type type, int mmq_x, bool need_check>
31273135
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
3128-
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) || defined(GCN)
3129-
__launch_bounds__(warp_size*get_mmq_nwarps_device(type), 2)
3136+
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
3137+
__launch_bounds__(ggml_cuda_get_physical_warp_size()*get_mmq_nwarps_device(type), 2)
31303138
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
31313139
#else
31323140
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
3133-
__launch_bounds__(warp_size*get_mmq_nwarps_device(type), 1)
3141+
__launch_bounds__(ggml_cuda_get_physical_warp_size()*get_mmq_nwarps_device(type), 1)
31343142
#else
3135-
__launch_bounds__(warp_size*get_mmq_nwarps_device(type), 2)
3143+
__launch_bounds__(ggml_cuda_get_physical_warp_size()*get_mmq_nwarps_device(type), 2)
31363144
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
31373145
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
31383146
static __global__ void mul_mat_q(
@@ -3149,6 +3157,7 @@ static __global__ void mul_mat_q(
31493157
}
31503158

31513159
constexpr int nwarps = get_mmq_nwarps_device(type);
3160+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
31523161

31533162
constexpr int qk = ggml_cuda_type_traits<type>::qk;
31543163
constexpr int mmq_y = get_mmq_y_device();
@@ -3373,7 +3382,7 @@ static __global__ void mul_mat_q(
33733382
}
33743383

33753384

3376-
template <ggml_type type, int mmq_x, int warp_size, bool need_check>
3385+
template <ggml_type type, int mmq_x, bool need_check>
33773386
static __global__ void mul_mat_q_stream_k_fixup(
33783387
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
33793388
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
@@ -3384,6 +3393,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
33843393
const int64_t blocks_per_ne00 = ncols_x / qk;
33853394

33863395
constexpr int nwarps = get_mmq_nwarps_device(type);
3396+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
33873397

33883398
float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
33893399

@@ -3531,8 +3541,8 @@ struct mmq_args {
35313541
bool use_stream_k;
35323542
};
35333543

3534-
template<ggml_type type, int warp_size>
3535-
static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int nwarps) {
3544+
template<ggml_type type>
3545+
static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
35363546
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
35373547
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
35383548
const size_t nbs_ids = mmq_x*sizeof(int);
@@ -3546,19 +3556,19 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
35463556
const int id = ggml_cuda_get_device();
35473557
const int cc = ggml_cuda_info().devices[id].cc;
35483558
const int nsm = ggml_cuda_info().devices[id].nsm;
3549-
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3559+
const int warp_size = ggml_cuda_info().devices[id].warp_size;
35503560
const int nwarps = get_mmq_nwarps_host(type, cc);
35513561
const int mmq_y = get_mmq_y_host(cc);
35523562

35533563
const dim3 block_dims(warp_size, nwarps, 1);
35543564

3555-
const int nbytes_shared = mmq_get_nbytes_shared<type, warp_size>(mmq_x, mmq_y, cc, nwarps);
3565+
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
35563566

35573567
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
35583568
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
35593569
if (!shared_memory_limit_raised[id]) {
3560-
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, warp_size, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3561-
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, warp_size, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3570+
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
3571+
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
35623572
shared_memory_limit_raised[id] = true;
35633573
}
35643574
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
@@ -3576,14 +3586,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
35763586
if (!args.use_stream_k) {
35773587
if (args.nrows_x % mmq_y == 0) {
35783588
constexpr bool need_check = false;
3579-
mul_mat_q<type, mmq_x, warp_size, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3589+
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
35803590
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
35813591
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
35823592
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
35833593
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
35843594
} else {
35853595
constexpr bool need_check = true;
3586-
mul_mat_q<type, mmq_x, warp_size, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
3596+
mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
35873597
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
35883598
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
35893599
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3603,7 +3613,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
36033613

36043614
if (args.nrows_x % mmq_y == 0) {
36053615
constexpr bool need_check = false;
3606-
mul_mat_q<type, mmq_x, warp_size, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3616+
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
36073617
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
36083618
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
36093619
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3613,12 +3623,12 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
36133623
return;
36143624
}
36153625

3616-
mul_mat_q_stream_k_fixup<type, mmq_x, warp_size, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3626+
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
36173627
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
36183628
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
36193629
} else {
36203630
constexpr bool need_check = true;
3621-
mul_mat_q<type, mmq_x, warp_size, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
3631+
mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
36223632
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
36233633
args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
36243634
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3628,19 +3638,19 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
36283638
return;
36293639
}
36303640

3631-
mul_mat_q_stream_k_fixup<type, mmq_x, warp_size, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
3641+
mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
36323642
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
36333643
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
36343644
}
36353645
}
36363646

36373647
template <ggml_type type>
36383648
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
3639-
const int id = ggml_cuda_get_device();
3640-
const int cc = ggml_cuda_info().devices[id].cc;
3641-
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
3642-
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3643-
const int nwarps = get_mmq_nwarps_host(type, cc);
3649+
const int id = ggml_cuda_get_device();
3650+
const int cc = ggml_cuda_info().devices[id].cc;
3651+
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
3652+
const int warp_size = ggml_cuda_info().devices[id].warp_size;
3653+
const int nwarps = get_mmq_nwarps_host(type, cc);
36443654

36453655
const int mmq_x_max = get_mmq_x_max_host(cc);
36463656
const int mmq_y = get_mmq_y_host(cc);
@@ -3651,7 +3661,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
36513661
for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
36523662
const int granularity = mmq_get_granularity_host(type, mmq_x, cc);
36533663

3654-
if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type, warp_size>(mmq_x, mmq_y, cc, nwarps) > smpbo) {
3664+
if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
36553665
continue;
36563666
}
36573667

0 commit comments

Comments (0)