
Commit dd79d3c

[CUDA] Faster rms norm for small dimension (#2838)
1 parent 704fd1a commit dd79d3c

File tree

1 file changed: +236 −42 lines

mlx/backend/cuda/rms_norm.cu

Lines changed: 236 additions & 42 deletions
@@ -22,33 +22,81 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
 }
 
 // Similar to cub::BlockReduce, but result is broadcasted to every thread.
-template <typename T, int BLOCK_DIM>
+template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
 struct BlockBroadcastReduce {
-  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
-  static_assert(BLOCK_DIM % WARP_SIZE == 0);
-  using TempStorage = T[BLOCK_DIM / WARP_SIZE];
+  using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
 
   cg::thread_block& block;
   TempStorage& temp;
 
   template <typename Op>
   __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
-    auto warp = cg::tiled_partition<WARP_SIZE>(block);
+    auto warp = cg::tiled_partition<GROUP_DIM>(block);
     T x = cg::reduce(warp, input, op);
-    if (warp.thread_rank() == 0) {
-      temp[warp.meta_group_rank()] = x;
+    if constexpr (BLOCK_DIM > GROUP_DIM) {
+      if (warp.thread_rank() == 0) {
+        temp[warp.meta_group_rank()] = x;
+      }
+      block.sync();
+      x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
+                                                      : init_value;
+      return cg::reduce(warp, x, op);
+    } else {
+      return x;
     }
-    block.sync();
-    x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
-                                                    : init_value;
-    return cg::reduce(warp, x, op);
   }
 
   __device__ T Sum(const T& input) {
     return Reduce(input, cg::plus<T>{}, T{});
   }
 };
 
+template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
+__global__ void rms_norm_small(
+    const T* x,
+    const T* w,
+    T* out,
+    float eps,
+    uint32_t axis_size,
+    uint32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceT::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+  x += row * axis_size;
+  out += row * axis_size;
+
+  // Normalizer.
+  float normalizer = 0;
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float t = static_cast<float>(xn[i]);
+    normalizer += t * t;
+  }
+
+  normalizer = BlockReduceT{block, temp}.Sum(normalizer);
+  normalizer = rsqrt(normalizer / axis_size + eps);
+
+  // Outputs.
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float y = static_cast<float>(xn[i]) * normalizer;
+    xn[i] = wn[i] * static_cast<T>(y);
+  }
+  store_vector<N_READS>(out, index, xn, axis_size);
+}
+
 template <typename T, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm(
     const T* x,
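
In rms_norm_small each y-slice of a block handles one row: the threads along x accumulate x^2 across the row, reduce it with BlockBroadcastReduce, and then scale by w. A minimal single-row reference of that math in plain C++ (the function name is illustrative, not part of this diff):

#include <cmath>
#include <vector>

// Sketch only: single-row reference for the math in rms_norm_small, written
// without vectorized loads or a block-level reduction.
void rms_norm_reference(
    const std::vector<float>& x,
    const std::vector<float>& w,
    std::vector<float>& out,
    float eps) {
  const size_t n = x.size();
  float sum_sq = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    sum_sq += x[i] * x[i];
  }
  // Same normalizer as the kernel: rsqrt(mean(x^2) + eps).
  float normalizer = 1.0f / std::sqrt(sum_sq / n + eps);
  for (size_t i = 0; i < n; ++i) {
    out[i] = w[i] * (x[i] * normalizer);
  }
}
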
@@ -94,6 +142,74 @@ __global__ void rms_norm(
   }
 }
 
+template <
+    typename T,
+    bool HAS_W,
+    int BLOCK_DIM,
+    int REDUCE_DIM,
+    int N_READS = 4>
+__global__ void rms_norm_vjp_small(
+    const T* x,
+    const T* w,
+    const T* g,
+    T* gx,
+    T* gw,
+    float eps,
+    int32_t axis_size,
+    int32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceF2::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+
+  x += row * axis_size;
+  g += row * axis_size;
+  gx += row * axis_size;
+  gw += row * axis_size;
+
+  // Normalizer.
+  float2 factors = {};
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+  auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+  for (int i = 0; i < N_READS; i++) {
+    float t = static_cast<float>(xn[i]);
+    float wi = wn[i];
+    float gi = gn[i];
+    float wg = wi * gi;
+    factors = plus_f2(factors, {wg * t, t * t});
+  }
+
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
+  float meangwx = factors.x / axis_size;
+  float normalizer = rsqrt(factors.y / axis_size + eps);
+  float normalizer3 = normalizer * normalizer * normalizer;
+
+  // Outputs.
+  for (int i = 0; i < N_READS; i++) {
+    float xi = xn[i];
+    float wi = wn[i];
+    float gi = gn[i];
+    xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
+    if constexpr (HAS_W) {
+      wn[i] = static_cast<T>(gi * xi * normalizer);
+    }
+  }
+  store_vector<N_READS>(gx, index, xn, axis_size);
+  if constexpr (HAS_W) {
+    store_vector<N_READS>(gw, index, wn, axis_size);
+  }
+}
+
 template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm_vjp(
     const T* x,
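
rms_norm_vjp_small applies the standard RMS-norm VJP per row: with normalizer = rsqrt(mean(x^2) + eps) and meangwx = mean(w*g*x), it writes gx = normalizer*w*g - x*meangwx*normalizer^3 and, when the weight gradient is requested, gw = g*x*normalizer. A plain single-row reference of the same math (hypothetical function, not part of the file):

#include <cmath>
#include <vector>

// Sketch only: single-row reference for the math in rms_norm_vjp_small.
void rms_norm_vjp_reference(
    const std::vector<float>& x,
    const std::vector<float>& w,
    const std::vector<float>& g,
    std::vector<float>& gx,
    std::vector<float>& gw,
    float eps) {
  const size_t n = x.size();
  float sum_wgx = 0.0f;
  float sum_xx = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    sum_wgx += w[i] * g[i] * x[i];
    sum_xx += x[i] * x[i];
  }
  float meangwx = sum_wgx / n;
  float normalizer = 1.0f / std::sqrt(sum_xx / n + eps);
  float normalizer3 = normalizer * normalizer * normalizer;
  for (size_t i = 0; i < n; ++i) {
    gx[i] = normalizer * w[i] * g[i] - x[i] * meangwx * normalizer3;
    gw[i] = g[i] * x[i] * normalizer;  // only used when the weight gradient is needed
  }
}
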
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
 
-  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
   using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
-  __shared__ union {
-    typename BlockReduceF::TempStorage f;
-    typename BlockReduceF2::TempStorage f2;
-  } temp;
+  __shared__ typename BlockReduceF2::TempStorage temp;
 
   x += grid.block_rank() * axis_size;
   g += grid.block_rank() * axis_size;
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
       factors = plus_f2(factors, {wg * t, t * t});
     }
   }
-  factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
   float meangwx = factors.x / axis_size;
   float normalizer = rsqrt(factors.y / axis_size + eps);
   float normalizer3 = normalizer * normalizer * normalizer;
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
   return s.device == Device::cpu;
 }
 
+template <int n_per_thread, typename F>
+void dispatch_group_dim(int axis_size, F&& f) {
+  if (axis_size <= n_per_thread * 8) {
+    f(std::integral_constant<int, 8>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 16>());
+  } else if (axis_size <= n_per_thread * 16) {
+    f(std::integral_constant<int, 16>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 8>());
+  } else if (axis_size <= n_per_thread * 32) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 4>());
+  } else if (axis_size <= n_per_thread * 32 * 2) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 2>(),
+      std::integral_constant<int, 2>());
+  } else if (axis_size <= n_per_thread * 32 * 4) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 4>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 8) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 8>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 16) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 16>(),
+      std::integral_constant<int, 1>());
+  } else {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 32>(),
+      std::integral_constant<int, 1>());
+  }
+}
+
 // TODO: There are duplicate code with backend/metal/normalization.cpp
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
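
dispatch_group_dim picks a (group_dim, n_groups, groups_per_block) triple so that group_dim * n_groups * n_per_thread just covers axis_size, while very short rows are packed several per block. With float16 data N_READS is 16 / sizeof(half) = 8, so axis_size = 512 lands in the axis_size <= 8 * 32 * 2 bucket: group_dim = 32, n_groups = 2, groups_per_block = 2, i.e. a 64x2 block processing two rows. A host-side sketch that mirrors the same buckets with runtime ints (names are illustrative, not from the file):

#include <cstdio>

// Sketch only: mirrors dispatch_group_dim's buckets so the chosen launch
// configuration can be printed and inspected.
struct GroupConfig {
  int group_dim;         // threads cooperating on one row's reduction
  int n_groups;          // groups stacked along x to cover the row
  int groups_per_block;  // rows handled per thread block (block.y)
};

GroupConfig pick_group_config(int axis_size, int n_per_thread) {
  if (axis_size <= n_per_thread * 8)       return {8, 1, 16};
  if (axis_size <= n_per_thread * 16)      return {16, 1, 8};
  if (axis_size <= n_per_thread * 32)      return {32, 1, 4};
  if (axis_size <= n_per_thread * 32 * 2)  return {32, 2, 2};
  if (axis_size <= n_per_thread * 32 * 4)  return {32, 4, 1};
  if (axis_size <= n_per_thread * 32 * 8)  return {32, 8, 1};
  if (axis_size <= n_per_thread * 32 * 16) return {32, 16, 1};
  return {32, 32, 1};
}

int main() {
  // With float16 data, N_READS = 16 / sizeof(half) = 8.
  for (int axis_size : {64, 512, 4096}) {
    GroupConfig c = pick_group_config(axis_size, /*n_per_thread=*/8);
    std::printf(
        "axis_size=%d -> block (%d, %d), %d rows per block\n",
        axis_size, c.group_dim * c.n_groups, c.groups_per_block,
        c.groups_per_block);
  }
  return 0;
}
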
@@ -216,20 +365,41 @@ void RMSNorm::eval_gpu(
   dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     constexpr int N_READS = 16 / sizeof(DataType);
-    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
+    if (axis_size <= N_READS * 1024) {
+      dispatch_group_dim<N_READS>(
+          axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
+            constexpr int block_dim = n_groups() * group_dim();
+            auto kernel =
+                cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
+            auto n_blocks =
+                (n_rows + groups_per_block() - 1) / groups_per_block();
+            encoder.add_kernel_node(
+                kernel,
+                n_blocks,
+                {block_dim, groups_per_block()},
+                0,
+                gpu_ptr<DataType>(x),
+                gpu_ptr<DataType>(w),
+                gpu_ptr<DataType>(out),
+                eps_,
+                axis_size,
+                n_rows,
+                w_stride);
+          });
+    } else {
+      auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
       encoder.add_kernel_node(
           kernel,
           n_rows,
-          block_dim(),
+          1024,
           0,
           gpu_ptr<DataType>(x),
           gpu_ptr<DataType>(w),
           gpu_ptr<DataType>(out),
           eps_,
           axis_size,
           w_stride);
-    });
+    }
   });
 }
 
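
The small-row launch above uses a 2-D block: x spans block_dim = group_dim * n_groups threads for one row, y stacks groups_per_block rows, and the grid holds ceil(n_rows / groups_per_block) blocks; inside the kernel the row index is grid.block_rank() * blockDim.y + threadIdx.y. A standalone sketch of that geometry with a dummy kernel (everything here is illustrative and uses a plain launch, not MLX's encoder API):

#include <cstdio>
#include <cuda_runtime.h>

// Sketch only: each (block, threadIdx.y) slice is assigned one row, just as
// in the rms_norm_small launch. The kernel records which block handled it.
__global__ void row_map_sketch(int* handled_by_block, int n_rows) {
  int row = blockIdx.x * blockDim.y + threadIdx.y;
  if (row >= n_rows) {
    return;
  }
  if (threadIdx.x == 0) {
    handled_by_block[row] = blockIdx.x;
  }
}

int main() {
  constexpr int block_dim = 64;        // group_dim * n_groups, e.g. for fp16 rows <= 512
  constexpr int groups_per_block = 2;  // two rows per block
  int n_rows = 7;
  int* d = nullptr;
  cudaMalloc(&d, n_rows * sizeof(int));
  dim3 block(block_dim, groups_per_block);
  dim3 grid((n_rows + groups_per_block - 1) / groups_per_block);  // 4 blocks
  row_map_sketch<<<grid, block>>>(d, n_rows);
  int h[7];
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  for (int r = 0; r < n_rows; ++r) {
    std::printf("row %d handled by block %d\n", r, h[r]);  // 0,0,1,1,2,2,3
  }
  cudaFree(d);
  return 0;
}
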

@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
     dispatch_bool(has_w, [&](auto has_w_constant) {
       using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       constexpr int N_READS = 16 / sizeof(DataType);
-      dispatch_block_dim(
-          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            auto kernel = cu::rms_norm_vjp<
-                DataType,
-                has_w_constant.value,
-                block_dim(),
-                N_READS>;
-            encoder.add_kernel_node(
-                kernel,
-                n_rows,
-                block_dim(),
-                0,
-                gpu_ptr<DataType>(x),
-                gpu_ptr<DataType>(w),
-                gpu_ptr<DataType>(g),
-                gpu_ptr<DataType>(gx),
-                gpu_ptr<DataType>(gw_temp),
-                eps_,
-                axis_size,
-                w_stride);
-          });
+      if (axis_size <= N_READS * 1024) {
+        dispatch_group_dim<N_READS>(
+            axis_size,
+            [&](auto group_dim, auto n_groups, auto groups_per_block) {
+              constexpr int block_dim = group_dim() * n_groups();
+              auto kernel = cu::rms_norm_vjp_small<
+                  DataType,
+                  has_w_constant.value,
+                  block_dim,
+                  group_dim(),
+                  N_READS>;
+              auto n_blocks =
+                  (n_rows + groups_per_block() - 1) / groups_per_block();
+              encoder.add_kernel_node(
+                  kernel,
+                  n_blocks,
+                  {block_dim, groups_per_block()},
+                  0,
+                  gpu_ptr<DataType>(x),
+                  gpu_ptr<DataType>(w),
+                  gpu_ptr<DataType>(g),
+                  gpu_ptr<DataType>(gx),
+                  gpu_ptr<DataType>(gw_temp),
+                  eps_,
+                  axis_size,
+                  n_rows,
+                  w_stride);
+            });
+      } else {
+        auto kernel =
+            cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
+        encoder.add_kernel_node(
+            kernel,
+            n_rows,
+            1024,
+            0,
+            gpu_ptr<DataType>(x),
+            gpu_ptr<DataType>(w),
+            gpu_ptr<DataType>(g),
+            gpu_ptr<DataType>(gx),
+            gpu_ptr<DataType>(gw_temp),
+            eps_,
+            axis_size,
+            w_stride);
+      }
     });
   });
