1616
1717#pragma once
1818
19- #include < stdio.h>
2019#include < assert.h>
2120#include < cuda_fp16.h>
2221#include < cfloat>
@@ -91,6 +90,117 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
9190 }
9291}
9392
93+
94+ /*
95+ * Extended softmax (from native aten pytorch) with following additional features
96+ * 1) input scaling
97+ */
98+ template <typename input_t , typename output_t , typename acc_t , int log2_elements>
99+ __global__ void scaled_softmax_warp_forward (
100+ output_t *dst,
101+ const input_t *src,
102+ const acc_t scale,
103+ int micro_batch_size,
104+ int element_count)
105+ {
106+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
107+ // warp_size of method warp_softmax_forward_kernel.
108+ constexpr int next_power_of_two = 1 << log2_elements;
109+ constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
110+ constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
111+ constexpr int WARP_BATCH = (next_power_of_two <= 128 ) ? 2 : 1 ;
112+ constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4 ) ? 1 : 4 ;
113+
114+ // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
115+ // gridDim/blockIdx = (seq_len, attn_heads, batches)
116+ int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z ))+ threadIdx.y ) * WARP_BATCH;
117+
118+ // micro_batch_size might not be a multiple of WARP_BATCH. Check how
119+ // many batches have to computed within this WARP.
120+ int local_batches = micro_batch_size - first_batch;
121+ if (local_batches > WARP_BATCH)
122+ local_batches = WARP_BATCH;
123+
124+ // there might be multiple batches per warp. compute the index within the batch
125+ int local_idx = threadIdx.x ;
126+
127+ src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
128+ dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
129+
130+ // load data from global memory
131+ acc_t elements[WARP_BATCH][WARP_ITERATIONS];
132+ input_t temp_data[ELEMENTS_PER_LDG_STG];
133+ #pragma unroll
134+ for (int i = 0 ; i < WARP_BATCH; ++i) {
135+ int batch_element_count = (i >= local_batches) ? 0 : element_count;
136+
137+ #pragma unroll
138+ for (int it = 0 ; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
139+ int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
140+
141+ if (element_index < batch_element_count) {
142+ int itr_idx = i*element_count+it*WARP_SIZE;
143+ copy_vector<input_t , ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
144+
145+ #pragma unroll
146+ for (int element = 0 ; element < ELEMENTS_PER_LDG_STG; ++element) {
147+ elements[i][it + element] = (acc_t )temp_data[element] * scale;
148+ }
149+ } else {
150+ #pragma unroll
151+ for (int element = 0 ; element < ELEMENTS_PER_LDG_STG; ++element) {
152+ elements[i][it + element] = -std::numeric_limits<acc_t >::infinity ();
153+ }
154+ }
155+ }
156+ }
157+
158+ // compute max_value
159+ acc_t max_value[WARP_BATCH];
160+ #pragma unroll
161+ for (int i = 0 ; i < WARP_BATCH; ++i) {
162+ max_value[i] = elements[i][0 ];
163+ #pragma unroll
164+ for (int it = 1 ; it < WARP_ITERATIONS; ++it) {
165+ max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
166+ }
167+ }
168+ warp_reduce<acc_t , WARP_BATCH, WARP_SIZE, Max>(max_value);
169+
170+ acc_t sum[WARP_BATCH] { 0 .0f };
171+ #pragma unroll
172+ for (int i = 0 ; i < WARP_BATCH; ++i) {
173+ #pragma unroll
174+ for (int it = 0 ; it < WARP_ITERATIONS; ++it) {
175+ elements[i][it] = std::exp ((elements[i][it] - max_value[i]));
176+ sum[i] += elements[i][it];
177+ }
178+ }
179+ warp_reduce<acc_t , WARP_BATCH, WARP_SIZE, Add>(sum);
180+
181+ // store result
182+ output_t out[ELEMENTS_PER_LDG_STG];
183+ #pragma unroll
184+ for (int i = 0 ; i < WARP_BATCH; ++i) {
185+ if (i >= local_batches)
186+ break ;
187+ #pragma unroll
188+ for (int it = 0 ; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
189+ int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
190+ if (element_index < element_count) {
191+ #pragma unroll
192+ for (int element = 0 ; element < ELEMENTS_PER_LDG_STG; ++element) {
193+ out[element] = elements[i][it + element] / sum[i];
194+ }
195+ copy_vector<output_t , ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
196+ } else {
197+ break ;
198+ }
199+ }
200+ }
201+ }
202+
203+
94204/*
95205 * Extended softmax (from native aten pytorch) with following additional features
96206 * 1) input scaling
@@ -112,7 +222,7 @@ __global__ void scaled_masked_softmax_warp_forward(
112222 constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
113223 constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
114224 constexpr int WARP_BATCH = (next_power_of_two <= 128 ) ? 2 : 1 ;
115- constexpr int ELEMENTS_PER_LDG_STG = 4 ;
225+ constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4 ) ? 1 : 4 ;
116226
117227 // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
118228 // gridDim/blockIdx = (seq_len, attn_heads, batches)
@@ -231,7 +341,7 @@ __global__ void scaled_masked_softmax_warp_backward(
231341 constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
232342 constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
233343 constexpr int WARP_BATCH = (next_power_of_two <= 128 ) ? 2 : 1 ;
234- constexpr int ELEMENTS_PER_LDG_STG = 4 ;
344+ constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4 ) ? 1 : 4 ;
235345
236346 // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
237347 // gridDim/blockIdx = (seq_len, attn_heads, batches)
@@ -317,7 +427,6 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att
317427 int log2_elements = log2_ceil (key_seq_len);
318428 const int next_power_of_two = 1 << log2_elements;
319429
320- int batch_count = batches * attn_heads * query_seq_len;
321430 int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
322431 int batches_per_warp = (next_power_of_two <= 128 ) ? 2 : 1 ;
323432
@@ -328,6 +437,98 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att
328437 return batches_per_block;
329438}
330439
440+ template <typename input_t , typename output_t , typename acc_t >
441+ void dispatch_scaled_softmax_forward (
442+ output_t *dst,
443+ const input_t *src,
444+ const input_t scale,
445+ int query_seq_len,
446+ int key_seq_len,
447+ int batches,
448+ int attn_heads)
449+ {
450+ TORCH_INTERNAL_ASSERT (key_seq_len >= 0 && key_seq_len <= 4096 );
451+ if (key_seq_len == 0 ) {
452+ return ;
453+ } else {
454+ int log2_elements = log2_ceil (key_seq_len);
455+ const int next_power_of_two = 1 << log2_elements;
456+ int batch_count = batches * attn_heads * query_seq_len;
457+
458+ // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
459+ int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
460+
461+ // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
462+ int batches_per_warp = (next_power_of_two <= 128 ) ? 2 : 1 ;
463+
464+ // use 128 threads per block to maximimize gpu utilization
465+ constexpr int threads_per_block = 128 ;
466+
467+ int warps_per_block = (threads_per_block / warp_size);
468+ int batches_per_block = warps_per_block * batches_per_warp;
469+ TORCH_INTERNAL_ASSERT (query_seq_len%batches_per_block == 0 );
470+ dim3 blocks (query_seq_len/batches_per_block, attn_heads, batches);
471+ dim3 threads (warp_size, warps_per_block, 1 );
472+ // Launch code would be more elegant if C++ supported FOR CONSTEXPR
473+ switch (log2_elements) {
474+ case 0 : // 1
475+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 0 >
476+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
477+ break ;
478+ case 1 : // 2
479+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 1 >
480+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
481+ break ;
482+ case 2 : // 4
483+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 2 >
484+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
485+ break ;
486+ case 3 : // 8
487+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 3 >
488+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
489+ break ;
490+ case 4 : // 16
491+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 4 >
492+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
493+ break ;
494+ case 5 : // 32
495+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 5 >
496+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
497+ break ;
498+ case 6 : // 64
499+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 6 >
500+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
501+ break ;
502+ case 7 : // 128
503+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 7 >
504+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
505+ break ;
506+ case 8 : // 256
507+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 8 >
508+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
509+ break ;
510+ case 9 : // 512
511+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 9 >
512+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
513+ break ;
514+ case 10 : // 1024
515+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 10 >
516+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
517+ break ;
518+ case 11 : // 2048
519+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 11 >
520+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
521+ break ;
522+ case 12 : // 4096
523+ scaled_softmax_warp_forward<input_t , output_t , acc_t , 12 >
524+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, scale, batch_count, key_seq_len);
525+ break ;
526+ default :
527+ break ;
528+ }
529+ }
530+ }
531+
331532template <typename input_t , typename output_t , typename acc_t >
332533void dispatch_scaled_masked_softmax_forward (
333534 output_t *dst,
@@ -340,6 +541,7 @@ void dispatch_scaled_masked_softmax_forward(
340541 int attn_heads,
341542 int pad_batches)
342543{
544+ TORCH_INTERNAL_ASSERT (key_seq_len >= 0 && key_seq_len <= 4096 );
343545 if (key_seq_len == 0 ) {
344546 return ;
345547 } else {
@@ -358,6 +560,7 @@ void dispatch_scaled_masked_softmax_forward(
358560
359561 int warps_per_block = (threads_per_block / warp_size);
360562 int batches_per_block = warps_per_block * batches_per_warp;
563+ TORCH_INTERNAL_ASSERT (query_seq_len%batches_per_block == 0 );
361564 dim3 blocks (query_seq_len/batches_per_block, attn_heads, batches);
362565 dim3 threads (warp_size, warps_per_block, 1 );
363566 // Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -410,6 +613,10 @@ void dispatch_scaled_masked_softmax_forward(
410613 scaled_masked_softmax_warp_forward<input_t , output_t , acc_t , 11 >
411614 <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
412615 break ;
616+ case 12 : // 4096
617+ scaled_masked_softmax_warp_forward<input_t , output_t , acc_t , 12 >
618+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
619+ break ;
413620 default :
414621 break ;
415622 }
@@ -427,6 +634,7 @@ void dispatch_scaled_masked_softmax_backward(
427634 int batches,
428635 int attn_heads)
429636{
637+ TORCH_INTERNAL_ASSERT ( key_seq_len >= 0 && key_seq_len <= 4096 );
430638 if (key_seq_len == 0 ) {
431639 return ;
432640 } else {
@@ -497,6 +705,11 @@ void dispatch_scaled_masked_softmax_backward(
497705 scaled_masked_softmax_warp_backward<input_t , output_t , acc_t , 11 >
498706 <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
499707 break ;
708+ case 12 : // 4096
709+ scaled_masked_softmax_warp_backward<input_t , output_t , acc_t , 12 >
710+ <<<blocks, threads, 0 , at::cuda::getCurrentCUDAStream ()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
711+ break ;
712+
500713 default :
501714 break ;
502715 }
0 commit comments