Skip to content

Commit 1062205

Browse files
authored
CUDA: some micro-optimizations in mmf.cuh for mul_mat_id (ggml-org#15926)
1 parent a68f31e commit 1062205

File tree

1 file changed

+23
-35
lines changed

1 file changed

+23
-35
lines changed

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -57,31 +57,33 @@ static __global__ void mul_mat_f(
5757
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
5858

5959
if constexpr (has_ids) {
60-
__shared__ int has_any;
61-
if (threadIdx.y == 0) {
62-
int local_has_any = 0;
63-
for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
64-
int slot = -1;
65-
for (int k = 0; k < nchannels_dst; ++k) {
66-
const int idv = ids[j*stride_row_id + k*stride_col_id];
67-
if (idv == expert_idx) {
68-
slot = k;
69-
break;
70-
}
71-
}
72-
if (j < cols_per_block) {
73-
local_has_any |= (slot >= 0);
74-
slot_map[j] = slot;
60+
int found = 0;
61+
62+
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
63+
const int j = j0 + threadIdx.y;
64+
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
65+
66+
if (threadIdx.x == 0) {
67+
slot_map[j] = -1;
68+
}
69+
70+
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
71+
int match = id_row[k*stride_col_id] == expert_idx;
72+
73+
if (match) {
74+
slot_map[j] = k;
75+
found = 1;
76+
break;
7577
}
7678
}
77-
has_any = warp_reduce_any(local_has_any);
7879
}
79-
__syncthreads();
80-
if (has_any == 0) {
80+
81+
if (!__syncthreads_or(found)) {
8182
return;
8283
}
8384
}
8485

86+
8587
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
8688
tile_A A[ntA][warp_size / tile_A::J];
8789
#pragma unroll
@@ -106,14 +108,7 @@ static __global__ void mul_mat_f(
106108
if constexpr (!has_ids) {
107109
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
108110
} else {
109-
float val = 0.0f;
110-
if (j < cols_per_block) {
111-
const int slot = slot_map[j];
112-
if (slot >= 0) {
113-
val = y[slot*stride_channel_y + j*stride_col_y + col];
114-
}
115-
}
116-
tile_xy[j0*tile_k_padded + threadIdx.x] = val;
111+
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
117112
}
118113
}
119114
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@@ -125,14 +120,7 @@ static __global__ void mul_mat_f(
125120
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
126121
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
127122
} else {
128-
float2 tmp = make_float2(0.0f, 0.0f);
129-
if (j < cols_per_block) {
130-
const int slot = slot_map[j];
131-
if (slot >= 0) {
132-
const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
133-
tmp = y2_slot[j*stride_col_y + col];
134-
}
135-
}
123+
float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
136124
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
137125
}
138126
}
@@ -221,7 +209,7 @@ static inline void mul_mat_f_switch_ids(
221209
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
222210
if (ids) {
223211
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
224-
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
212+
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
225213
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
226214
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
227215
} else {

0 commit comments

Comments
 (0)