
Commit 1cb76a6

stas00 and hyunwoongko authored
sync the whole Meg-LM fused_kernels sub-tree (#260)
* sync the whole Meg-LM fused_kernels sub-tree
* author attribution
* part 2

Co-authored-by: hyunwoongko <[email protected]>
1 parent 8673d46 commit 1cb76a6

File tree

4 files changed: +265 additions, −13 deletions

megatron/fused_kernels/scaled_masked_softmax.h

Lines changed: 217 additions & 4 deletions
@@ -16,7 +16,6 @@
 
 #pragma once
 
-#include <stdio.h>
 #include <assert.h>
 #include <cuda_fp16.h>
 #include <cfloat>
@@ -91,6 +90,117 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
     }
 }
 
+
+/*
+ * Extended softmax (from native aten pytorch) with following additional features
+ * 1) input scaling
+ */
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_softmax_warp_forward(
+    output_t *dst,
+    const input_t *src,
+    const acc_t scale,
+    int micro_batch_size,
+    int element_count)
+{
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
+    // warp_size of method warp_softmax_forward_kernel.
+    constexpr int next_power_of_two = 1 << log2_elements;
+    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
+
+    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
+    // gridDim/blockIdx = (seq_len, attn_heads, batches)
+    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
+
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
+    // many batches have to be computed within this WARP.
+    int local_batches = micro_batch_size - first_batch;
+    if (local_batches > WARP_BATCH)
+        local_batches = WARP_BATCH;
+
+    // there might be multiple batches per warp. compute the index within the batch
+    int local_idx = threadIdx.x;
+
+    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+
+    // load data from global memory
+    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    input_t temp_data[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : element_count;
+
+        #pragma unroll
+        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+
+            if (element_index < batch_element_count) {
+                int itr_idx = i * element_count + it * WARP_SIZE;
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    elements[i][it + element] = (acc_t)temp_data[element] * scale;
+                }
+            } else {
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
+            }
+        }
+    }
+
+    // compute max_value
+    acc_t max_value[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        max_value[i] = elements[i][0];
+        #pragma unroll
+        for (int it = 1; it < WARP_ITERATIONS; ++it) {
+            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
+
+    acc_t sum[WARP_BATCH] { 0.0f };
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        #pragma unroll
+        for (int it = 0; it < WARP_ITERATIONS; ++it) {
+            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+            sum[i] += elements[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+    // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+        if (i >= local_batches)
+            break;
+        #pragma unroll
+        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+            if (element_index < element_count) {
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    out[element] = elements[i][it + element] / sum[i];
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
+            } else {
+                break;
+            }
+        }
+    }
+}
+
+
 /*
  * Extended softmax (from native aten pytorch) with following additional features
  * 1) input scaling
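The kernel added above is the standard numerically stable softmax, parallelized across warp lanes with per-thread register tiles. As a reference, here is a scalar C++ sketch of the same four phases (scaled load, max-reduce, exp-and-sum, normalize); the name scaled_softmax_row and the std::vector interface are illustrative only, not part of the kernel:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Scalar reference for what scaled_softmax_warp_forward computes per row.
// Hypothetical helper for illustration; assumes a non-empty row.
std::vector<float> scaled_softmax_row(const std::vector<float>& src, float scale) {
    std::vector<float> x(src.size());
    for (size_t i = 0; i < src.size(); ++i)
        x[i] = src[i] * scale;                                     // phase 1: scaled load

    float max_value = *std::max_element(x.begin(), x.end());      // phase 2: row max

    float sum = 0.0f;
    for (float& v : x) {                                           // phase 3: exp of shifted values, summed
        v = std::exp(v - max_value);
        sum += v;
    }

    for (float& v : x)                                             // phase 4: normalize
        v /= sum;
    return x;
}
```

Subtracting the row max before exponentiating keeps std::exp in range; the out-of-row lanes in the kernel are padded with -infinity, which contribute exp(-inf) = 0 to the sum.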
@@ -112,7 +222,7 @@ __global__ void scaled_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
 
     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
@@ -231,7 +341,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
 
     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
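The ELEMENTS_PER_LDG_STG change in both hunks above guards the short-row case: each thread's register tile holds WARP_ITERATIONS elements, so a fixed 4-wide vectorized load/store would index past that tile whenever WARP_ITERATIONS < 4. A compile-time walk-through of the arithmetic for a 64-element row, assuming C10_WARP_SIZE is 32:

```cpp
// Constant arithmetic for a 64-element row; mirrors the constexpr math in the
// kernels above, with C10_WARP_SIZE assumed to be 32.
constexpr int kWarpSizeAssumed = 32;
constexpr int kLog2Elements    = 6;                     // rows of 33..64 elements
constexpr int kNextPowerOfTwo  = 1 << kLog2Elements;    // 64
constexpr int kWarpSize        = (kNextPowerOfTwo < kWarpSizeAssumed)
                                     ? kNextPowerOfTwo : kWarpSizeAssumed;   // 32
constexpr int kWarpIterations  = kNextPowerOfTwo / kWarpSize;                // 2

// Before this commit ELEMENTS_PER_LDG_STG was hard-coded to 4, so the loops
// could touch elements[i][it + element] with it + element up to 3 -- past a
// 2-entry per-thread tile. The new conditional falls back to scalar accesses:
constexpr int kElementsPerLdgStg = (kWarpIterations < 4) ? 1 : 4;            // 1
static_assert(kElementsPerLdgStg <= kWarpIterations,
              "vector width must not exceed the per-thread element count");
```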
@@ -317,7 +427,6 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads)
     int log2_elements = log2_ceil(key_seq_len);
     const int next_power_of_two = 1 << log2_elements;
 
-    int batch_count = batches * attn_heads * query_seq_len;
     int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -328,6 +437,98 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads)
     return batches_per_block;
 }
 
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_softmax_forward(
+    output_t *dst,
+    const input_t *src,
+    const input_t scale,
+    int query_seq_len,
+    int key_seq_len,
+    int batches,
+    int attn_heads)
+{
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
+    if (key_seq_len == 0) {
+        return;
+    } else {
+        int log2_elements = log2_ceil(key_seq_len);
+        const int next_power_of_two = 1 << log2_elements;
+        int batch_count = batches * attn_heads * query_seq_len;
+
+        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
+        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
+        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
+        dim3 threads(warp_size, warps_per_block, 1);
+        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+        switch (log2_elements) {
+            case 0: // 1
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 1: // 2
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 2: // 4
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 3: // 8
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 4: // 16
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 5: // 32
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 6: // 64
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 7: // 128
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 8: // 256
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 9: // 512
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 10: // 1024
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 11: // 2048
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 12: // 4096
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            default:
+                break;
+        }
+    }
+}
+
 template<typename input_t, typename output_t, typename acc_t>
 void dispatch_scaled_masked_softmax_forward(
     output_t *dst,
@@ -340,6 +541,7 @@ void dispatch_scaled_masked_softmax_forward(
     int attn_heads,
     int pad_batches)
 {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
     if (key_seq_len == 0) {
         return;
     } else {
@@ -358,6 +560,7 @@ void dispatch_scaled_masked_softmax_forward(
 
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
         dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -410,6 +613,10 @@ void dispatch_scaled_masked_softmax_forward(
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                     <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
+            case 12: // 4096
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
             default:
                 break;
         }
@@ -427,6 +634,7 @@ void dispatch_scaled_masked_softmax_backward(
     int batches,
     int attn_heads)
 {
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
     if (key_seq_len == 0) {
         return;
     } else {
@@ -497,6 +705,11 @@ void dispatch_scaled_masked_softmax_backward(
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                     <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
+            case 12: // 4096
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+
             default:
                 break;
         }

megatron/fused_kernels/scaled_masked_softmax_cuda.cu

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ torch::Tensor fwd_cuda(
     const int attn_heads = input.size(1);
     const int query_seq_len = input.size(2);
     const int key_seq_len = input.size(3);
-    TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+    TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
     TORCH_INTERNAL_ASSERT(query_seq_len > 1);
     TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
     TORCH_INTERNAL_ASSERT(mask.size(1) == 1);

megatron/fused_kernels/scaled_upper_triang_masked_softmax.h

Lines changed: 8 additions & 2 deletions
@@ -125,7 +125,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
 
     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;
@@ -245,7 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
 
     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;
@@ -340,6 +340,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
     int softmax_elements_stride,
     int attn_batches)
 {
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
     if (softmax_elements == 0) {
         return;
     } else {
@@ -359,6 +360,8 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
 
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);
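The new attn_batches % batches_per_block == 0 asserts (added to both the forward and backward dispatchers) turn what used to be silent truncation in blocks_per_seq = attn_batches / batches_per_block into a hard error. The constraint is tightest for short sequences; a worked instance, again assuming a 32-lane warp:

```cpp
// For softmax_elements in 33..128: next_power_of_two is 64 or 128, so
// warp_size = 32 and batches_per_warp = 2 (next_power_of_two <= 128).
constexpr int kThreadsPerBlock = 128;
constexpr int kWarpSize        = 32;
constexpr int kBatchesPerWarp  = 2;
constexpr int kWarpsPerBlock   = kThreadsPerBlock / kWarpSize;        // 4
constexpr int kBatchesPerBlock = kWarpsPerBlock * kBatchesPerWarp;    // 8
static_assert(kBatchesPerBlock == 8, "8 rows per block for short sequences");
// attn_batches (batches * attn_heads in the callers) must then be a multiple
// of 8, or the integer division attn_batches / kBatchesPerBlock would drop
// the remainder; the new TORCH_INTERNAL_ASSERT rejects such inputs instead.
```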
@@ -428,6 +431,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
     int softmax_elements_stride,
     int attn_batches)
 {
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
     if (softmax_elements == 0) {
         return;
     } else {
@@ -447,6 +451,8 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
 
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
         int blocks_per_seq = attn_batches / batches_per_block;
         dim3 blocks(seq_len, blocks_per_seq, 1);
         dim3 threads(warp_size, warps_per_block, 1);
