18 | 18 | ROWBMMWeightNoTp, |
19 | 19 | ) |
20 | 20 | from functools import partial |
21 | | - |
22 | | -import triton |
23 | | -import triton.language as tl |
24 | | -from triton import Config |
25 | | - |
26 | | - |
27 | | -def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: |
28 | | - """ |
29 | | - Dequantizes the given weight tensor using the provided scale tensor. |
30 | | -
31 | | - Args: |
32 | | - x (torch.Tensor): The quantized weight tensor of shape (M, N). |
33 | | - s (torch.Tensor): The scale tensor of shape (M, N). |
34 | | - block_size (int, optional): The block size to use for dequantization. Defaults to 128. |
35 | | -
36 | | - Returns: |
37 | | - torch.Tensor: The dequantized weight tensor of the same shape as `x`. |
38 | | -
39 | | - Raises: |
40 | | - AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. |
41 | | - """ |
42 | | - assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" |
43 | | - assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" |
44 | | - M, N = x.size() |
45 | | - y = torch.empty_like(x, dtype=torch.get_default_dtype()) |
46 | | - grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"])) |
47 | | - weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) |
48 | | - return y.to(torch.bfloat16) |
49 | | - |
50 | | - |
51 | | -@triton.jit |
52 | | -def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): |
53 | | - """ |
54 | | - Dequantizes weights using the provided scaling factors and stores the result. |
55 | | -
56 | | - Args: |
57 | | - x_ptr (tl.pointer): Pointer to the quantized weights. |
58 | | - s_ptr (tl.pointer): Pointer to the scaling factors. |
59 | | - y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. |
60 | | - M (int): Number of rows in the weight matrix. |
61 | | - N (int): Number of columns in the weight matrix. |
62 | | - BLOCK_SIZE (tl.constexpr): Size of the block for tiling. |
63 | | -
64 | | - Returns: |
65 | | - None |
66 | | - """ |
67 | | - pid_m = tl.program_id(axis=0) |
68 | | - pid_n = tl.program_id(axis=1) |
69 | | - n = tl.cdiv(N, BLOCK_SIZE) |
70 | | - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) |
71 | | - offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) |
72 | | - offs = offs_m[:, None] * N + offs_n[None, :] |
73 | | - mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) |
74 | | - x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) |
75 | | - s = tl.load(s_ptr + pid_m * n + pid_n) |
76 | | - y = x * s |
77 | | - tl.store(y_ptr + offs, y, mask=mask) |
| 21 | +from ..triton_kernel.weight_dequant import weight_dequant |
78 | 22 |
79 | 23 |
80 | 24 | class Deepseek2TransformerLayerWeight(TransformerLayerWeight): |
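
Note (not part of the commit): the diff only relocates the inline Triton dequantization code into a shared `..triton_kernel.weight_dequant` module and imports it here. The sketch below shows how the relocated helper could be called; it assumes the moved module keeps the exact signature removed above (`weight_dequant(x, s, block_size=128)` returning a bfloat16 tensor) and that scales are laid out one per `block_size x block_size` tile, as the removed kernel's `s_ptr + pid_m * n + pid_n` indexing implies.

```python
# Usage sketch only (assumptions noted above): requires a CUDA device with
# Triton support and PyTorch >= 2.1 for the float8 dtype.
import torch

# Repo-relative import as added in this diff; outside the package, import via
# the project's full module path instead.
from ..triton_kernel.weight_dequant import weight_dequant

M, N, block_size = 4096, 4096, 128
# FP8 block-quantized weight and its per-tile float32 scales.
x = torch.randn(M, N, device="cuda").to(torch.float8_e4m3fn)
s = torch.rand(M // block_size, N // block_size, device="cuda", dtype=torch.float32)
w = weight_dequant(x, s, block_size)  # -> torch.bfloat16 tensor of shape (M, N)
```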