
Commit 9c323ac

add rmsnorm-add fusion kernel

1 parent 250318a

File tree: 3 files changed, +342 −1 lines changed

lightllm/models/llama/triton_kernel/fused_add_rmsnorm_inplace.py

Lines changed: 155 additions & 0 deletions
```python
import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_fused_add_rmsnorm(
    original,
    residual,
    weight,
    original_stride0,
    original_stride1,
    residual_stride0,
    residual_stride1,
    N,  # number of columns in X
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    block_id = tl.program_id(0)
    # base address of the row handled by this program
    _original = original + block_id * original_stride0
    _residual = residual + block_id * residual_stride0

    # Fast path: the whole row fits in one block, so every element is
    # loaded from global memory exactly once. For large sizes this gives
    # better performance than re-loading in a second pass.
    if N <= BLOCK_SIZE:
        # element offsets within the row
        cols = tl.arange(0, BLOCK_SIZE)
        _original_offset = cols * original_stride1
        _residual_offset = cols * residual_stride1
        _weight_offset = cols

        # element pointers within the row
        _original_ptr = _original + _original_offset
        _residual_ptr = _residual + _residual_offset
        _weight_ptr = weight + _weight_offset

        # load data from global memory
        mask = cols < N
        original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
        residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32)
        weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32)

        # store (original + residual) to original
        original_cache = original_cache + residual_cache
        tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask)

        # compute variance
        var = tl.sum(original_cache * original_cache) / N
        rstd = 1 / tl.sqrt(var + eps)
        residual_cache = original_cache * rstd * weight_cache

        # store rmsnorm(original + residual) back to residual
        tl.store(_residual_ptr, residual_cache.to(residual.dtype.element_ty), mask=mask)
    else:
        # Two-pass path for rows larger than BLOCK_SIZE.
        # Pass 1: write the fused add while accumulating the sum of squares.
        sum_of_squares = tl.zeros([], dtype=tl.float32)
        for block_offset in range(0, N, BLOCK_SIZE):
            # element offsets of this tile
            cols = tl.arange(0, BLOCK_SIZE) + block_offset
            _original_offset = cols * original_stride1
            _residual_offset = cols * residual_stride1

            # element pointers of this tile
            _original_ptr = _original + _original_offset
            _residual_ptr = _residual + _residual_offset

            # load data from global memory
            mask = cols < N
            original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
            residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32)

            # store (original + residual) to original
            original_cache = original_cache + residual_cache
            tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask)

            # accumulate the sum of squares
            sum_of_squares += tl.sum(original_cache * original_cache)

        # compute variance
        var = sum_of_squares / N
        rstd = 1 / tl.sqrt(var + eps)

        # Pass 2: re-load the summed row and apply the normalization.
        for block_offset in range(0, N, BLOCK_SIZE):
            # element offsets of this tile
            cols = tl.arange(0, BLOCK_SIZE) + block_offset
            _original_offset = cols * original_stride1
            _residual_offset = cols * residual_stride1
            _weight_offset = cols

            # element pointers of this tile
            _original_ptr = _original + _original_offset
            _residual_ptr = _residual + _residual_offset
            _weight_ptr = weight + _weight_offset

            # load data from global memory
            mask = cols < N
            original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
            weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32)

            # apply rmsnorm using the pre-computed rstd
            original_cache = original_cache * rstd * weight_cache

            # store rmsnorm(original) back to residual
            tl.store(_residual_ptr, original_cache.to(residual.dtype.element_ty), mask=mask)


def fused_add_rmsnorm_inplace(
    original: torch.Tensor,  # [num_tokens, hidden_size]
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
):
    """
    Perform a fused add & rmsnorm, in place.

    If the skip connection is H(x) = F(x) + x, then F(x) is `residual` and
    x is `original`. On return, `original` holds (original + residual) and
    `residual` holds rmsnorm(original + residual).
    At the first layer, `residual` should be all zeros.
    """
    # reshape the inputs into 2D tensors
    original_arg = original.view(-1, original.shape[-1])
    residual_arg = residual.view(-1, residual.shape[-1])

    assert original.data_ptr() == original_arg.data_ptr()
    assert residual.data_ptr() == residual_arg.data_ptr()

    M, N = original_arg.shape
    # less than 64KB per feature row: enqueue the fused kernel
    MAX_FUSED_SIZE = 65536 // original.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))

    if N > BLOCK_SIZE:
        raise RuntimeError("This rmsnorm kernel doesn't support feature rows larger than 64KB.")

    # heuristic for the number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
    num_warps = triton.next_power_of_2(num_warps)
    if BLOCK_SIZE > 16384:
        BLOCK_SIZE = 16384

    # enqueue the kernel: one program per row
    _fwd_fused_add_rmsnorm[(M,)](
        original_arg,
        residual_arg,
        weight,
        original_arg.stride(0),
        original_arg.stride(1),
        residual_arg.stride(0),
        residual_arg.stride(1),
        N,  # number of columns in X
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
```
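
For reference, the in-place contract of `fused_add_rmsnorm_inplace` can be sketched in plain PyTorch (a minimal sketch, not part of the commit; `add_rmsnorm_reference` and its argument names are hypothetical). Note how the dispatch arithmetic above feeds the kernel's two branches: for bf16 inputs (2 bytes per element), MAX_FUSED_SIZE is 65536 // 2 = 32768, so a row of N = 32768 passes the size check, has its BLOCK_SIZE capped from 32768 to 16384, and therefore takes the kernel's two-pass branch.

```python
import torch


def add_rmsnorm_reference(x: torch.Tensor, r: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
    """Pure-PyTorch sketch of what the fused kernel does in place."""
    h = x.float() + r.float()  # fused add, accumulated in fp32 like the kernel
    rstd = torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + eps)
    x.copy_(h.to(x.dtype))                        # original <- original + residual
    r.copy_((h * rstd * w.float()).to(r.dtype))   # residual <- rmsnorm(original + residual)
```

After the call, `x` carries the pre-norm hidden state forward to the next layer's skip connection, while `r` holds the normalized activations, which is why the Triton kernel writes both buffers in a single pass over global memory.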

lightllm/utils/custom_kernel_utis.py

Lines changed: 124 additions & 1 deletion
Lines changed: 124 additions & 1 deletion

The only change to the existing lines is the import:

```diff
@@ -1,7 +1,7 @@
 import torch
 import triton
 import triton.language as tl
-from typing import List
+from typing import List, Callable
```
The two helpers below are appended after `pad2dim_tensor_to_new_batch` (hunk `@@ -125,3 +125,126 @@`):
```python
def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> float:
    """
    Compute the SNR error between y_pred and y_real.

    The SNR error is calculated as:

        SNR(pred, real) = sum((pred - real) ^ 2) / sum(real ^ 2)

    Both tensors are flattened first, so for matrices the error is
    computed over all elements. Lower is better; 0 is a perfect match.

    Args:
        y_pred (torch.Tensor): predicted values.
        y_real (torch.Tensor): reference values.

    Raises:
        ValueError: if the two tensors have different shapes.

    Returns:
        float: the SNR error.
    """
    y_pred = torch.flatten(y_pred).float()
    y_real = torch.flatten(y_real).float()

    if y_pred.shape != y_real.shape:
        raise ValueError(
            f"Can not compute snr loss for tensors with different shape. ({y_pred.shape} and {y_real.shape})"
        )

    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
    signal_power = torch.pow(y_real, 2).sum(dim=-1)
    snr = noise_power / (signal_power + 1e-7)
    return snr.item()


def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
    """
    A helper for performance testing of CUDA operations.

    This function will:
    1. Collect every `torch.Tensor` in the argument list and in the
       output of `func`.
    2. Compute the GPU memory traffic of those input and output tensors
       (based on their dtype and `torch.numel()`).
    3. Warm up, then execute `func` repeatedly for `steps` iterations.
    4. Record the execution time with CUDA events.
    5. Print the compute performance (TFLOPS) and memory throughput (GB/s).

    Args:
        func (Callable): the function to benchmark.
        shape (List[int]): the problem shape, used to label the printed report.
        tflops (float): the computational workload (in TFLOPs) per call of `func`.
        steps (int): the number of times `func` is executed during benchmarking.
        *args: positional arguments passed to `func`.
        **kwargs: keyword arguments passed to `func`.
    """
    # Ensure CUDA is available
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for benchmarking.")

    # Collect torch.Tensor inputs
    input_tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
    input_tensors += [value for value in kwargs.values() if isinstance(value, torch.Tensor)]

    def calculate_memory(tensor: torch.Tensor):
        """Calculate memory usage in bytes for a tensor."""
        return tensor.numel() * tensor.element_size()

    input_memory = sum(calculate_memory(t) for t in input_tensors)

    # Execute the function once to inspect its outputs
    with torch.no_grad():
        output = func(*args, **kwargs)

    output_memory = 0
    if isinstance(output, torch.Tensor):
        output_memory = calculate_memory(output)
    elif isinstance(output, (list, tuple)):
        output_memory = sum(calculate_memory(o) for o in output if isinstance(o, torch.Tensor))

    total_memory = input_memory + output_memory

    # Warm-up
    for _ in range(10):
        func(*args, **kwargs)

    torch.cuda.synchronize()  # ensure no pending operations

    # Benchmark the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(steps):
        func(*args, **kwargs)
    end_event.record()

    torch.cuda.synchronize()  # ensure all operations have finished
    elapsed_time_ms = start_event.elapsed_time(end_event)  # time in milliseconds

    # Calculate performance metrics
    elapsed_time_s = elapsed_time_ms / 1000  # convert to seconds
    avg_time_per_step = elapsed_time_s / steps
    compute_performance = tflops / avg_time_per_step  # TFLOPS
    memory_throughput = (total_memory * steps / (1024 ** 3)) / elapsed_time_s  # GB/s

    # Print performance metrics
    print(f"Function: {func.__name__}{shape}")
    print(f"Elapsed Time (total): {elapsed_time_s:.4f} seconds")
    print(f"Average Time Per Step: {avg_time_per_step * 1000:.3f} ms")
    print(f"Compute Performance: {compute_performance:.2f} TFLOPS")
    print(f"Memory Throughput: {memory_throughput:.2f} GB/s")
    print("")  # print a blank line
```
Lines changed: 63 additions & 0 deletions
```python
import unittest
import torch
from lightllm.models.llama.triton_kernel.fused_add_rmsnorm_inplace import fused_add_rmsnorm_inplace
from lightllm.utils.custom_kernel_utis import benchmark, error


class TestFusedAddRmsNormInplace(unittest.TestCase):
    def setUp(self):
        """Set up common test parameters."""
        self.tokens = [1, 2, 3, 1024, 2048, 4096, 8192, 16384]
        self.dims = [1, 2, 3, 512, 1024, 1025, 3200, 16384, 32768]
        self.device = "cuda"
        self.dtype = torch.bfloat16

    def torch_add_rmsnorm(self, X, R, W):
        # reference implementation: in-place add, then torch's rms_norm
        X.add_(R)
        return torch.nn.functional.rms_norm(X, (X.shape[1],), W, eps=1e-6)

    def test_accuracy(self):
        """Test the accuracy of fused_add_rmsnorm_inplace against torch's rms_norm."""
        for token_num in self.tokens:
            for dim in self.dims:
                with self.subTest(shape=[token_num, dim]):
                    X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                    _X = X.clone()
                    R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                    _R = R.clone()
                    W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)

                    r_real = self.torch_add_rmsnorm(_X, _R, W)
                    fused_add_rmsnorm_inplace(X, R, W, eps=1e-6)
                    r_pred = R
                    self.assertTrue(
                        error(r_pred, r_real) < 0.01,
                        f"Accuracy test failed for size {token_num}, {dim}. r_real={r_real}, r_pred={r_pred}",
                    )
                    print(f"{error(r_pred, r_real) = }")

                    x_real = _X
                    x_pred = X
                    self.assertTrue(
                        error(x_pred, x_real) < 0.01,
                        f"Accuracy test failed for size {token_num}, {dim}. x_real={x_real}, x_pred={x_pred}",
                    )
                    print(f"{error(x_pred, x_real) = }")

    def test_performance(self):
        """Benchmark the fused kernel against the torch reference."""
        for token_num in self.tokens:
            for dim in self.dims:
                with self.subTest(shape=[token_num, dim]):
                    X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                    R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                    W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)

                    shape = [token_num, dim]
                    tflops = 0.0
                    benchmark(self.torch_add_rmsnorm, shape, tflops, 100, X, R, W)
                    benchmark(fused_add_rmsnorm_inplace, shape, tflops, 100, X, R, W, eps=1e-6)


if __name__ == "__main__":
    unittest.main()
```
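
For a quick standalone smoke check outside the unittest sweep, the commit's own APIs can be reused like this (a hedged sketch; the 4×4096 shape is arbitrary):

```python
import torch
from lightllm.models.llama.triton_kernel.fused_add_rmsnorm_inplace import fused_add_rmsnorm_inplace
from lightllm.utils.custom_kernel_utis import error

X = torch.randn(4, 4096, device="cuda", dtype=torch.bfloat16)
R = torch.randn(4, 4096, device="cuda", dtype=torch.bfloat16)
W = torch.randn(4096, device="cuda", dtype=torch.bfloat16)

# torch reference: in-place add, then rms_norm
ref_x = X.clone().add_(R)
ref_r = torch.nn.functional.rms_norm(ref_x, (4096,), W, eps=1e-6)

fused_add_rmsnorm_inplace(X, R, W, eps=1e-6)
print(error(X, ref_x), error(R, ref_r))  # both should be well below 0.01
```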
