
Commit 3b60919

add gelu and rms_norm
1 parent b237506 commit 3b60919

File tree

4 files changed: +101 -7 lines

  lightllm/models/vit/layer_infer/post_layer_infer.py
  lightllm/models/vit/layer_infer/transformer_layer_infer.py
  lightllm/models/vit/triton_kernel/gelu_vit.py
  lightllm/models/vit/triton_kernel/rms_norm_vit.py

lightllm/models/vit/layer_infer/post_layer_infer.py

Lines changed: 4 additions & 3 deletions
@@ -3,7 +3,7 @@
 import torch.distributed as dist
 from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
-
+from lightllm.models.vit.triton_kernel.gelu_vit import gelu
 
 class ViTPostLayerInfer:
     """ """
@@ -44,8 +44,9 @@ def forward(self, vit_embeds, layer_weight: ViTPreAndPostLayerWeight):
             layer_weight.mlp1_1_bias_, vit_embeds_norm.view(-1, vit_embeds_norm.shape[-1]), layer_weight.mlp1_1_weight_
         )
 
-        vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1)
-
+        # vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1)
+        vit_embeds_gelu = gelu(vit_embeds_1)
+
         vit_embeds_out = torch.addmm(
             layer_weight.mlp1_3_bias_,
             vit_embeds_gelu.view(-1, self.llm_hidden_size // self.tp_world_size_),

lightllm/models/vit/layer_infer/transformer_layer_infer.py

Lines changed: 6 additions & 4 deletions
@@ -10,7 +10,8 @@
 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
-
+from lightllm.models.vit.triton_kernel.gelu_vit import gelu
+from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
 
 class ViTTransformerLayerInfer:
     """ """
@@ -58,7 +59,7 @@ def tp_norm(self, input, weight):
 
     def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         if layer_weight.norm_type == "rms_norm":
-            b = rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
+            b = rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
         else:
             b = torch.nn.functional.layer_norm(
                 input,
@@ -71,7 +72,7 @@ def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Ten
 
     def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         if layer_weight.norm_type == "rms_norm":
-            return rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
+            return rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
         else:
             return torch.nn.functional.layer_norm(
                 input,
@@ -113,7 +114,8 @@ def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor
 
     def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
-        ffn1_out = torch.nn.functional.gelu(fc1)
+        # ffn1_out = torch.nn.functional.gelu(fc1)
+        ffn1_out = gelu(fc1)
         input_shape = input.shape
         input = None
         ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=False)
lightllm/models/vit/triton_kernel/gelu_vit.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import torch
+import triton
+import triton.language as tl
+
+
+
+@triton.jit
+def gelu(x):
+    x_fp32 = x.to(tl.float32)
+    x_gelu = 0.5 * x_fp32 * (1 + tl.math.erf(x_fp32 * 0.7071067811))
+    return x_gelu
+
+# Triton kernel: apply GELU elementwise to the flattened input
+@triton.jit
+def gelu_kernel(output_ptr, input_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    input = tl.load(input_ptr + offsets, mask=mask)
+    output = gelu(input)
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+# Custom torch.autograd.Function wrapping the kernel launch
+class GeluTriton(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        output = torch.empty_like(input)
+        n_elements = input.numel()
+        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+        gelu_kernel[grid](output, input, n_elements, BLOCK_SIZE=1024)
+        return output
+
+gelu = GeluTriton.apply
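
Because gelu() computes the exact erf-based GELU in float32 before storing back in the input dtype, its output can be checked directly against torch.nn.functional.gelu. A minimal sanity-check sketch, not part of this commit, assuming a CUDA device, an importable lightllm package, and an arbitrary fp16 tensor shape:

    # Hypothetical check: compare the Triton GELU with PyTorch's exact (erf-based) GELU.
    import torch
    from lightllm.models.vit.triton_kernel.gelu_vit import gelu

    x = torch.randn(4, 3200, device="cuda", dtype=torch.float16)
    ref = torch.nn.functional.gelu(x)   # 0.5 * x * (1 + erf(x / sqrt(2)))
    out = gelu(x)                       # GeluTriton.apply -> gelu_kernel launch
    print((out - ref).abs().max())      # expect only fp16 rounding-level differences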
lightllm/models/vit/triton_kernel/rms_norm_vit.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import torch
+import triton
+import triton.language as tl
+from torch import Tensor
+
+
+@triton.jit
+def rms_norm_kernel(
+    input,
+    weight,
+    output,
+    input_row_stride: tl.constexpr,
+    eps: tl.constexpr,
+    N_COLS: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    """Rms norm kernel."""
+    prog_id = tl.program_id(0)
+    offsets = tl.arange(0, BLOCK_N)
+
+    w = tl.load(weight + offsets, mask=offsets < N_COLS)
+
+    x_ptr = input + prog_id * input_row_stride
+    x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)
+    xf = x.to(tl.float32)
+
+    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
+    out = xf / tl.sqrt(var + eps)
+    out = (w * out).to(x.dtype)
+
+    out_ptr = output + prog_id * input_row_stride
+    tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
+
+
+def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-5):
+    """Rms norm."""
+    feat_size = weight.shape[0]
+    seq_len = hidden_states.numel() // hidden_states.size(-1)
+    input_stride = hidden_states.stride(-2)
+
+    BLOCK_N = triton.next_power_of_2(feat_size)
+    out = torch.empty_like(hidden_states)
+
+    grid = (seq_len,)
+    rms_norm_kernel[grid](
+        hidden_states,
+        weight,
+        out,
+        input_row_stride=input_stride,
+        eps=eps,
+        N_COLS=feat_size,
+        BLOCK_N=BLOCK_N,
+        num_warps=4,
+        num_stages=3,
+    )
+
+    return out
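
The wrapper treats hidden_states as seq_len rows of feat_size elements and uses stride(-2) as the row stride, so it expects a row-contiguous 2D (or trivially reshapeable) input. A minimal sanity-check sketch against a float32 PyTorch reference, not part of this commit, assuming a CUDA device and an fp16 input of arbitrary row count:

    # Hypothetical check: compare the Triton rms_norm with a float32 reference.
    import torch
    from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm

    hidden = torch.randn(16, 1024, device="cuda", dtype=torch.float16)
    weight = torch.randn(1024, device="cuda", dtype=torch.float16)

    xf = hidden.float()
    ref = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * weight.float()
    out = rms_norm(hidden, weight, eps=1e-6)
    print((out.float() - ref).abs().max())  # expect only fp16 rounding-level differences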
