@@ -57,31 +57,33 @@ static __global__ void mul_mat_f(
     T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
 
     if constexpr (has_ids) {
-        __shared__ int has_any;
-        if (threadIdx.y == 0) {
-            int local_has_any = 0;
-            for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
-                int slot = -1;
-                for (int k = 0; k < nchannels_dst; ++k) {
-                    const int idv = ids[j*stride_row_id + k*stride_col_id];
-                    if (idv == expert_idx) {
-                        slot = k;
-                        break;
-                    }
-                }
-                if (j < cols_per_block) {
-                    local_has_any |= (slot >= 0);
-                    slot_map[j] = slot;
+        int found = 0;
+
+        for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+            const int32_t * __restrict__ id_row = ids + j*stride_row_id;
+
+            if (threadIdx.x == 0) {
+                slot_map[j] = -1;
+            }
+
+            for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
+                int match = id_row[k*stride_col_id] == expert_idx;
+
+                if (match) {
+                    slot_map[j] = k;
+                    found = 1;
+                    break;
                 }
             }
-            has_any = warp_reduce_any(local_has_any);
         }
-        __syncthreads();
-        if (has_any == 0) {
+
+        if (!__syncthreads_or(found)) {
             return;
         }
     }
 
+
     for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
         tile_A A[ntA][warp_size / tile_A::J];
 #pragma unroll
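
The rewritten gating above changes two things. First, the slot-map build is parallelized: instead of warp 0 scanning every column by itself (one thread per column, with a serial loop over nchannels_dst), each warp now takes a column j and its lanes stride across the channels, so the expert lookup uses the whole block. Second, the separate __shared__ flag, warp_reduce_any(), and __syncthreads() are folded into a single __syncthreads_or(found), which synchronizes the block and OR-reduces found over all threads in one step. A minimal, self-contained sketch of that early-exit idiom (hypothetical demo kernel, not part of this PR):

    // Each thread scans a strided slice of `ids`; the block exits early
    // if no thread found `target`. __syncthreads_or() returns non-zero
    // iff the predicate was non-zero for at least one thread in the block.
    __global__ void early_exit_demo(const int * __restrict__ ids, int n, int target) {
        int found = 0;
        for (int i = threadIdx.x; i < n; i += blockDim.x) {
            if (ids[i] == target) {
                found = 1;
                break;
            }
        }
        // Every thread must reach this call, including those that broke early.
        if (!__syncthreads_or(found)) {
            return; // uniform decision: no thread matched, skip the block
        }
        // ... the real work would continue here ...
    }
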
@@ -106,14 +108,7 @@ static __global__ void mul_mat_f(
                 if constexpr (!has_ids) {
                     tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
                 } else {
-                    float val = 0.0f;
-                    if (j < cols_per_block) {
-                        const int slot = slot_map[j];
-                        if (slot >= 0) {
-                            val = y[slot*stride_channel_y + j*stride_col_y + col];
-                        }
-                    }
-                    tile_xy[j0*tile_k_padded + threadIdx.x] = val;
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
                 }
             }
         } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
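
In the float path, the old staged load (a val temporary, a bounds check, a slot check) collapses into one predicated expression: an in-range column j reads y through the channel recorded in slot_map[j] during the gating phase above. A small sketch of this gather-through-a-slot-map pattern, with assumed shapes and hypothetical names:

    // One block per column j; threads gather that column's elements from
    // the channel chosen for it (slot_map[j] < 0 means "no expert matched").
    __global__ void gather_by_slot(float * __restrict__ dst, const float * __restrict__ y,
                                   const int * __restrict__ slot_map,
                                   int ncols_per_row, int stride_channel) {
        const int j    = blockIdx.x;
        const int slot = slot_map[j];
        for (int i = threadIdx.x; i < ncols_per_row; i += blockDim.x) {
            dst[j*ncols_per_row + i] = slot >= 0
                ? y[slot*stride_channel + j*ncols_per_row + i] : 0.0f;
        }
    }
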
@@ -125,14 +120,7 @@ static __global__ void mul_mat_f(
                     const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
                     tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                 } else {
-                    float2 tmp = make_float2(0.0f, 0.0f);
-                    if (j < cols_per_block) {
-                        const int slot = slot_map[j];
-                        if (slot >= 0) {
-                            const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
-                            tmp = y2_slot[j*stride_col_y + col];
-                        }
-                    }
+                    float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2 *) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
                     tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
                 }
             }
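
The half2/bf16 path keeps the slot_map[j] >= 0 guard but drops the intermediate y2_slot pointer in favor of a direct reinterpret-cast load. The factor of 2 appears because y is indexed in float elements while the loop counts float2 pairs: pair index p lives at float offset 2*p. Sketched in isolation (assuming the accessed addresses are 8-byte aligned, which float2 loads require):

    // Vectorized load: read two consecutive floats as a single float2.
    // `pair_idx` counts float2 elements within `buf`.
    __device__ float2 load_pair(const float * __restrict__ buf, int pair_idx) {
        return *(const float2 *) &buf[2*pair_idx];
    }
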
@@ -221,7 +209,7 @@ static inline void mul_mat_f_switch_ids(
         const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
     if (ids) {
         mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, 
+            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     } else {
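
For context, mul_mat_f_switch_ids is the host-side dispatcher (this hunk only trims trailing whitespace): a runtime null check on ids selects between the has_ids=true and has_ids=false instantiations, so the if constexpr (has_ids) branches above are resolved at compile time and cost nothing per thread. A reduced sketch of the idiom (hypothetical names):

    template <bool has_ids>
    __global__ void kernel_demo(const int32_t * ids) {
        if constexpr (has_ids) {
            // expert-routing path; `ids` is guaranteed non-null here
        } else {
            // dense path; `ids` is never dereferenced
        }
    }

    static void launch_demo(const int32_t * ids, dim3 grid, dim3 block,
                            size_t smem, cudaStream_t stream) {
        if (ids) {
            kernel_demo<true ><<<grid, block, smem, stream>>>(ids);
        } else {
            kernel_demo<false><<<grid, block, smem, stream>>>(ids);
        }
    }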