Commit 1c3ac8b

[BugFix] Fix silu kernel
1 parent adccb9d · commit 1c3ac8b

File tree: 1 file changed (+6, −6 lines)


lightllm/models/llama/triton_kernel/silu_and_mul.py

Lines changed: 6 additions & 6 deletions
@@ -17,8 +17,8 @@ def _silu_and_mul_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    stride_input_m = stride_input_m.to(tl.int64)
-    stride_output_m = stride_output_m.to(tl.int64)
+    stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
+    stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)
 
     tid = tl.program_id(0)
     input_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)
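
Why this fixes the kernel: likely because Triton can specialize an integer
kernel argument (for example, a stride whose runtime value is 1) into a
compile-time constant, and such a constant is a plain value with no .to()
method, so stride_input_m.to(tl.int64) can fail at compile time, while
tl.cast accepts it. Widening the strides to int64 also keeps the row-offset
products from overflowing 32-bit arithmetic on very large tensors. A minimal
self-contained sketch of the same pattern (the kernel and names here are
illustrative, not code from this repo):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _copy_rows_kernel(in_ptr, out_ptr, stride_m, N, BLOCK_N: tl.constexpr):
        pid = tl.program_id(0)
        # tl.cast also works if stride_m was specialized to a constant;
        # stride_m.to(tl.int64) would only work on a tensor value.
        stride_m = tl.cast(stride_m, dtype=tl.int64)
        offs = tl.arange(0, BLOCK_N)
        mask = offs < N
        # pid * stride_m is evaluated in int64, so the address offset
        # cannot wrap around for tensors with more than 2**31 elements.
        row = tl.load(in_ptr + pid * stride_m + offs, mask=mask)
        tl.store(out_ptr + pid * stride_m + offs, row, mask=mask)

    x = torch.randn((4096, 1024), device="cuda")
    y = torch.empty_like(x)
    _copy_rows_kernel[(x.shape[0],)](x, y, x.stride(0), x.shape[1], BLOCK_N=1024)
    assert torch.equal(x, y)
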
@@ -53,7 +53,7 @@ def _silu_and_mul_kernel(
     )
 
 
-def silu_and_mul_fwd(input, output):
+def silu_and_mul_fwd(input: torch.Tensor, output):
     stride_input_m = input.stride(0)
     stride_input_n = input.stride(1)
     stride_output_m = output.stride(0)
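
For context: silu_and_mul_fwd consumes an input of shape (M, N) and fills a
caller-provided output of shape (M, N // 2); the change above only adds a
type annotation. Judging from the test below, the reference
torch_silu_and_mul (whose definition line appears in the next hunk header)
applies SiLU to the first half of the last dimension and multiplies it
elementwise by the second half, roughly equivalent to this sketch (inferred
semantics, not the repo's code):

    import torch

    def torch_silu_and_mul_reference(input: torch.Tensor) -> torch.Tensor:
        # Split (M, N) into two (M, N // 2) halves: a gate and an up half.
        gate, up = input.chunk(2, dim=-1)
        # SiLU(x) = x * sigmoid(x); scale elementwise by the second half.
        return torch.nn.functional.silu(gate) * up
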
@@ -88,13 +88,13 @@ def torch_silu_and_mul(input: torch.Tensor):
 def test_silu_and_mul(M, N, dtype, device="cuda"):
     # create data
     X = torch.randn((M, N), dtype=dtype, device=device)
-
+    y_tri = torch.empty((M, N // 2), dtype=dtype, device=device)
     # run
-    y_tri = silu_and_mul_fwd(X)
+    silu_and_mul_fwd(X, y_tri)
     y_ref = torch_silu_and_mul(X)
 
     # compare
     print("type:", y_tri.dtype, y_ref.dtype)
     print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
-    assert torch.allclose(y_tri, y_ref, atol=1e-6, rtol=0)
+    assert torch.allclose(y_tri, y_ref, atol=1e-5, rtol=0)
     return
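
The test itself was also broken: the old call y_tri = silu_and_mul_fwd(X)
passes only one argument to a two-argument function, which raises a
TypeError for the missing output parameter. The fix preallocates y_tri with
shape (M, N // 2) and passes it in, and the absolute tolerance is loosened
from 1e-6 to 1e-5, presumably to absorb small numerical differences between
the fused kernel and the PyTorch reference. A hypothetical direct invocation
(M, N, and dtype are illustrative; N must be even so the last dimension
splits in half):

    import torch
    from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd

    M, N = 1024, 8192
    x = torch.randn((M, N), dtype=torch.float32, device="cuda")
    y = torch.empty((M, N // 2), dtype=torch.float32, device="cuda")
    silu_and_mul_fwd(x, y)  # the result is written into y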
