@@ -22,33 +22,81 @@ inline __device__ float2 plus_f2(const float2& a, const float2& b) {
 }
 
 // Similar to cub::BlockReduce, but result is broadcasted to every thread.
-template <typename T, int BLOCK_DIM>
+template <typename T, int BLOCK_DIM, int GROUP_DIM = WARP_SIZE>
 struct BlockBroadcastReduce {
-  static_assert(WARP_SIZE <= BLOCK_DIM && BLOCK_DIM <= WARP_SIZE * WARP_SIZE);
-  static_assert(BLOCK_DIM % WARP_SIZE == 0);
-  using TempStorage = T[BLOCK_DIM / WARP_SIZE];
+  using TempStorage = T[std::max(BLOCK_DIM / WARP_SIZE, 1)];
 
   cg::thread_block& block;
   TempStorage& temp;
 
   template <typename Op>
   __device__ T Reduce(const T& input, const Op& op, const T& init_value) {
-    auto warp = cg::tiled_partition<WARP_SIZE>(block);
+    auto warp = cg::tiled_partition<GROUP_DIM>(block);
     T x = cg::reduce(warp, input, op);
-    if (warp.thread_rank() == 0) {
-      temp[warp.meta_group_rank()] = x;
+    if constexpr (BLOCK_DIM > GROUP_DIM) {
+      if (warp.thread_rank() == 0) {
+        temp[warp.meta_group_rank()] = x;
+      }
+      block.sync();
+      x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
+                                                      : init_value;
+      return cg::reduce(warp, x, op);
+    } else {
+      return x;
     }
-    block.sync();
-    x = warp.thread_rank() < warp.meta_group_size() ? temp[warp.thread_rank()]
-                                                    : init_value;
-    return cg::reduce(warp, x, op);
   }
 
   __device__ T Sum(const T& input) {
     return Reduce(input, cg::plus<T>{}, T{});
   }
 };
 
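Note on the reducer above: with the new GROUP_DIM parameter, a block whose size equals the group size never touches shared memory or syncs (the `if constexpr` branch compiles away), while larger blocks combine per-group partials through TempStorage, which implicitly assumes BLOCK_DIM <= GROUP_DIM * GROUP_DIM. A minimal usage sketch, not part of this diff, assuming the file's `cg` namespace alias and WARP_SIZE constant:

    // Hypothetical kernel: 64 threads reduce in two warps of 32 and the sum is
    // broadcast back to every thread through the two-slot TempStorage.
    __global__ void example_block_sum(const float* in, float* out) {
      constexpr int kBlockDim = 64;
      using Reducer = BlockBroadcastReduce<float, kBlockDim>; // GROUP_DIM = WARP_SIZE
      __shared__ typename Reducer::TempStorage temp;
      auto block = cg::this_thread_block();
      float total = Reducer{block, temp}.Sum(in[block.thread_rank()]);
      out[block.thread_rank()] = total; // same value in all 64 threads
    }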
+template <typename T, int BLOCK_DIM, int REDUCE_DIM, int N_READS = 4>
+__global__ void rms_norm_small(
+    const T* x,
+    const T* w,
+    T* out,
+    float eps,
+    uint32_t axis_size,
+    uint32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceT = BlockBroadcastReduce<float, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceT::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+  x += row * axis_size;
+  out += row * axis_size;
+
+  // Normalizer.
+  float normalizer = 0;
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float t = static_cast<float>(xn[i]);
+    normalizer += t * t;
+  }
+
+  normalizer = BlockReduceT{block, temp}.Sum(normalizer);
+  normalizer = rsqrt(normalizer / axis_size + eps);
+
+  // Outputs.
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+#pragma unroll
+  for (int i = 0; i < N_READS; ++i) {
+    float y = static_cast<float>(xn[i]) * normalizer;
+    xn[i] = wn[i] * static_cast<T>(y);
+  }
+  store_vector<N_READS>(out, index, xn, axis_size);
+}
+
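The new rms_norm_small above reads its whole row in one vectorized pass and packs several rows into a block along y (row = block_rank * blockDim.y + threadIdx.y), unlike the existing rms_norm below, which loops a single block over one long row. A rough launch sketch under assumed parameters; the real launch goes through the encoder in eval_gpu further down, and kBlockDim / kRowsPerBlock are illustrative names only:

    // Hypothetical raw launch: float32 (N_READS == 4), one warp per row
    // (BLOCK_DIM == REDUCE_DIM == 32), four rows packed into each {32, 4} block.
    constexpr int kBlockDim = 32, kRowsPerBlock = 4;
    dim3 block(kBlockDim, kRowsPerBlock);
    dim3 grid((n_rows + kRowsPerBlock - 1) / kRowsPerBlock);
    rms_norm_small<float, kBlockDim, kBlockDim, 4>
        <<<grid, block>>>(x, w, out, eps, axis_size, n_rows, w_stride);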
 template <typename T, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm(
     const T* x,
@@ -94,6 +142,74 @@ __global__ void rms_norm(
   }
 }
 
+template <
+    typename T,
+    bool HAS_W,
+    int BLOCK_DIM,
+    int REDUCE_DIM,
+    int N_READS = 4>
+__global__ void rms_norm_vjp_small(
+    const T* x,
+    const T* w,
+    const T* g,
+    T* gx,
+    T* gw,
+    float eps,
+    int32_t axis_size,
+    int32_t n_rows,
+    int64_t w_stride) {
+  auto grid = cg::this_grid();
+  auto block = cg::this_thread_block();
+
+  using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM, REDUCE_DIM>;
+  __shared__ typename BlockReduceF2::TempStorage temp;
+
+  auto row =
+      (grid.block_rank() * block.dim_threads().y) + block.thread_index().y;
+  if (row >= n_rows) {
+    return;
+  }
+
+  x += row * axis_size;
+  g += row * axis_size;
+  gx += row * axis_size;
+  gw += row * axis_size;
+
+  // Normalizer.
+  float2 factors = {};
+  auto index = block.thread_index().x;
+  auto xn = load_vector<N_READS>(x, index, axis_size, T(0));
+  auto gn = load_vector<N_READS>(g, index, axis_size, T(0));
+  auto wn = load_vector<N_READS>(w, index, axis_size, w_stride, T(0));
+  for (int i = 0; i < N_READS; i++) {
+    float t = static_cast<float>(xn[i]);
+    float wi = wn[i];
+    float gi = gn[i];
+    float wg = wi * gi;
+    factors = plus_f2(factors, {wg * t, t * t});
+  }
+
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
+  float meangwx = factors.x / axis_size;
+  float normalizer = rsqrt(factors.y / axis_size + eps);
+  float normalizer3 = normalizer * normalizer * normalizer;
+
+  // Outputs.
+  for (int i = 0; i < N_READS; i++) {
+    float xi = xn[i];
+    float wi = wn[i];
+    float gi = gn[i];
+    xn[i] = static_cast<T>(normalizer * wi * gi - xi * meangwx * normalizer3);
+    if constexpr (HAS_W) {
+      wn[i] = static_cast<T>(gi * xi * normalizer);
+    }
+  }
+  store_vector<N_READS>(gx, index, xn, axis_size);
+  if constexpr (HAS_W) {
+    store_vector<N_READS>(gw, index, wn, axis_size);
+  }
+}
+
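For reference, the two loops in rms_norm_vjp_small implement the usual RMSNorm VJP; writing n for axis_size and r for the normalizer, meangwx is the mean in the second term of the x-gradient and normalizer3 is r cubed:

    r = \Big(\tfrac{1}{n}\sum_j x_j^2 + \epsilon\Big)^{-1/2}, \qquad
    \frac{\partial L}{\partial x_i} = r\, w_i g_i
        - x_i r^3 \cdot \frac{1}{n}\sum_j w_j g_j x_j, \qquad
    \frac{\partial L}{\partial w_i} = g_i x_i r.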
 template <typename T, bool HAS_W, int BLOCK_DIM, int N_READS = 4>
 __global__ void rms_norm_vjp(
     const T* x,
@@ -107,12 +223,8 @@ __global__ void rms_norm_vjp(
   auto grid = cg::this_grid();
   auto block = cg::this_thread_block();
 
-  using BlockReduceF = BlockBroadcastReduce<float, BLOCK_DIM>;
   using BlockReduceF2 = BlockBroadcastReduce<float2, BLOCK_DIM>;
-  __shared__ union {
-    typename BlockReduceF::TempStorage f;
-    typename BlockReduceF2::TempStorage f2;
-  } temp;
+  __shared__ typename BlockReduceF2::TempStorage temp;
 
   x += grid.block_rank() * axis_size;
   g += grid.block_rank() * axis_size;
@@ -134,7 +246,7 @@ __global__ void rms_norm_vjp(
       factors = plus_f2(factors, {wg * t, t * t});
     }
   }
-  factors = BlockReduceF2{block, temp.f2}.Reduce(factors, plus_f2, {});
+  factors = BlockReduceF2{block, temp}.Reduce(factors, plus_f2, {});
   float meangwx = factors.x / axis_size;
   float normalizer = rsqrt(factors.y / axis_size + eps);
   float normalizer3 = normalizer * normalizer * normalizer;
@@ -169,6 +281,43 @@ bool RMSNorm::use_fallback(Stream s) {
   return s.device == Device::cpu;
 }
 
+template <int n_per_thread, typename F>
+void dispatch_group_dim(int axis_size, F&& f) {
+  if (axis_size <= n_per_thread * 8) {
+    f(std::integral_constant<int, 8>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 16>());
+  } else if (axis_size <= n_per_thread * 16) {
+    f(std::integral_constant<int, 16>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 8>());
+  } else if (axis_size <= n_per_thread * 32) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 1>(),
+      std::integral_constant<int, 4>());
+  } else if (axis_size <= n_per_thread * 32 * 2) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 2>(),
+      std::integral_constant<int, 2>());
+  } else if (axis_size <= n_per_thread * 32 * 4) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 4>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 8) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 8>(),
+      std::integral_constant<int, 1>());
+  } else if (axis_size <= n_per_thread * 32 * 16) {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 16>(),
+      std::integral_constant<int, 1>());
+  } else {
+    f(std::integral_constant<int, 32>{},
+      std::integral_constant<int, 32>(),
+      std::integral_constant<int, 1>());
+  }
+}
+
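A worked example of the dispatch table above, assuming N_READS == 4 (float32 with 16-byte loads): axis_size == 100 satisfies 100 <= 4 * 32, so the third branch fires and the callback sees group_dim == 32, n_groups == 1, groups_per_block == 4, i.e. one warp reduces each row and four rows share a {32, 4} block. A hypothetical host-side illustration:

    dispatch_group_dim<4>(
        100, [&](auto group_dim, auto n_groups, auto groups_per_block) {
          // Here: group_dim == 32, n_groups == 1, groups_per_block == 4.
          constexpr int block_dim = group_dim() * n_groups();
          static_assert(block_dim <= 1024, "never exceeds the CUDA block limit");
        });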
 // TODO: There are duplicate code with backend/metal/normalization.cpp
 void RMSNorm::eval_gpu(
     const std::vector<array>& inputs,
@@ -216,20 +365,41 @@ void RMSNorm::eval_gpu(
   dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
     using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
     constexpr int N_READS = 16 / sizeof(DataType);
-    dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-      auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
+    if (axis_size <= N_READS * 1024) {
+      dispatch_group_dim<N_READS>(
+          axis_size, [&](auto group_dim, auto n_groups, auto groups_per_block) {
+            constexpr int block_dim = n_groups() * group_dim();
+            auto kernel =
+                cu::rms_norm_small<DataType, block_dim, group_dim(), N_READS>;
+            auto n_blocks =
+                (n_rows + groups_per_block() - 1) / groups_per_block();
+            encoder.add_kernel_node(
+                kernel,
+                n_blocks,
+                {block_dim, groups_per_block()},
+                0,
+                gpu_ptr<DataType>(x),
+                gpu_ptr<DataType>(w),
+                gpu_ptr<DataType>(out),
+                eps_,
+                axis_size,
+                n_rows,
+                w_stride);
+          });
+    } else {
+      auto kernel = cu::rms_norm<DataType, 1024, N_READS>;
       encoder.add_kernel_node(
           kernel,
           n_rows,
-          block_dim(),
+          1024,
           0,
           gpu_ptr<DataType>(x),
           gpu_ptr<DataType>(w),
           gpu_ptr<DataType>(out),
           eps_,
           axis_size,
           w_stride);
-    });
+    }
   });
 }
 
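The cutoff `axis_size <= N_READS * 1024` is exactly the largest row a single 1024-thread block can load in one vectorized pass: for float32 (N_READS == 4) the small kernel covers rows up to 4096 elements, and for 16-bit types (N_READS == 8) up to 8192. Longer rows keep taking the looping rms_norm path with a fixed 1024 threads; the same threshold is reused for the VJP dispatch below.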
@@ -306,27 +476,51 @@ void RMSNormVJP::eval_gpu(
     dispatch_bool(has_w, [&](auto has_w_constant) {
       using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       constexpr int N_READS = 16 / sizeof(DataType);
-      dispatch_block_dim(
-          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
-            auto kernel = cu::rms_norm_vjp<
-                DataType,
-                has_w_constant.value,
-                block_dim(),
-                N_READS>;
-            encoder.add_kernel_node(
-                kernel,
-                n_rows,
-                block_dim(),
-                0,
-                gpu_ptr<DataType>(x),
-                gpu_ptr<DataType>(w),
-                gpu_ptr<DataType>(g),
-                gpu_ptr<DataType>(gx),
-                gpu_ptr<DataType>(gw_temp),
-                eps_,
-                axis_size,
-                w_stride);
-          });
+      if (axis_size <= N_READS * 1024) {
+        dispatch_group_dim<N_READS>(
+            axis_size,
+            [&](auto group_dim, auto n_groups, auto groups_per_block) {
+              constexpr int block_dim = group_dim() * n_groups();
+              auto kernel = cu::rms_norm_vjp_small<
+                  DataType,
+                  has_w_constant.value,
+                  block_dim,
+                  group_dim(),
+                  N_READS>;
+              auto n_blocks =
+                  (n_rows + groups_per_block() - 1) / groups_per_block();
+              encoder.add_kernel_node(
+                  kernel,
+                  n_blocks,
+                  {block_dim, groups_per_block()},
+                  0,
+                  gpu_ptr<DataType>(x),
+                  gpu_ptr<DataType>(w),
+                  gpu_ptr<DataType>(g),
+                  gpu_ptr<DataType>(gx),
+                  gpu_ptr<DataType>(gw_temp),
+                  eps_,
+                  axis_size,
+                  n_rows,
+                  w_stride);
+            });
+      } else {
+        auto kernel =
+            cu::rms_norm_vjp<DataType, has_w_constant.value, 1024, N_READS>;
+        encoder.add_kernel_node(
+            kernel,
+            n_rows,
+            1024,
+            0,
+            gpu_ptr<DataType>(x),
+            gpu_ptr<DataType>(w),
+            gpu_ptr<DataType>(g),
+            gpu_ptr<DataType>(gx),
+            gpu_ptr<DataType>(gw_temp),
+            eps_,
+            axis_size,
+            w_stride);
+      }
     });
   });
 