Skip to content

Commit d615f11

Browse files
committed
fix
1 parent dd0bf1b commit d615f11

File tree

2 files changed

+60
-57
lines changed

2 files changed

+60
-57
lines changed

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,63 +18,7 @@
1818
ROWBMMWeightNoTp,
1919
)
2020
from functools import partial
21-
22-
import triton
23-
import triton.language as tl
24-
from triton import Config
25-
26-
27-
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
28-
"""
29-
Dequantizes the given weight tensor using the provided scale tensor.
30-
31-
Args:
32-
x (torch.Tensor): The quantized weight tensor of shape (M, N).
33-
s (torch.Tensor): The scale tensor of shape (M, N).
34-
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
35-
36-
Returns:
37-
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
38-
39-
Raises:
40-
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
41-
"""
42-
assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
43-
assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
44-
M, N = x.size()
45-
y = torch.empty_like(x, dtype=torch.get_default_dtype())
46-
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
47-
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
48-
return y.to(torch.bfloat16)
49-
50-
51-
@triton.jit
52-
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
53-
"""
54-
Dequantizes weights using the provided scaling factors and stores the result.
55-
56-
Args:
57-
x_ptr (tl.pointer): Pointer to the quantized weights.
58-
s_ptr (tl.pointer): Pointer to the scaling factors.
59-
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
60-
M (int): Number of rows in the weight matrix.
61-
N (int): Number of columns in the weight matrix.
62-
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
63-
64-
Returns:
65-
None
66-
"""
67-
pid_m = tl.program_id(axis=0)
68-
pid_n = tl.program_id(axis=1)
69-
n = tl.cdiv(N, BLOCK_SIZE)
70-
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
71-
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
72-
offs = offs_m[:, None] * N + offs_n[None, :]
73-
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
74-
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
75-
s = tl.load(s_ptr + pid_m * n + pid_n)
76-
y = x * s
77-
tl.store(y_ptr + offs, y, mask=mask)
21+
from ..triton_kernel.weight_dequant import weight_dequant
7822

7923

8024
class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# adapt from
2+
# https://github.com/deepseek-ai/DeepSeek-V3/blob/f09f5fa321f5a421704136c0463b1eaca6557712/inference/kernel.py
3+
import torch
4+
import triton
5+
import triton.language as tl
6+
from triton import Config
7+
8+
9+
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
10+
"""
11+
Dequantizes the given weight tensor using the provided scale tensor.
12+
13+
Args:
14+
x (torch.Tensor): The quantized weight tensor of shape (M, N).
15+
s (torch.Tensor): The scale tensor of shape (M, N).
16+
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
17+
18+
Returns:
19+
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
20+
21+
Raises:
22+
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
23+
"""
24+
assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
25+
assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
26+
M, N = x.size()
27+
y = torch.empty_like(x, dtype=torch.get_default_dtype())
28+
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
29+
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
30+
return y.to(torch.bfloat16)
31+
32+
33+
@triton.jit
34+
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
35+
"""
36+
Dequantizes weights using the provided scaling factors and stores the result.
37+
38+
Args:
39+
x_ptr (tl.pointer): Pointer to the quantized weights.
40+
s_ptr (tl.pointer): Pointer to the scaling factors.
41+
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
42+
M (int): Number of rows in the weight matrix.
43+
N (int): Number of columns in the weight matrix.
44+
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
45+
46+
Returns:
47+
None
48+
"""
49+
pid_m = tl.program_id(axis=0)
50+
pid_n = tl.program_id(axis=1)
51+
n = tl.cdiv(N, BLOCK_SIZE)
52+
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
53+
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
54+
offs = offs_m[:, None] * N + offs_n[None, :]
55+
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
56+
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
57+
s = tl.load(s_ptr + pid_m * n + pid_n)
58+
y = x * s
59+
tl.store(y_ptr + offs, y, mask=mask)

0 commit comments

Comments
 (0)