import torch
import triton
import triton.language as tl

"""Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py"""


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors, one per BLOCK_SIZE x BLOCK_SIZE tile.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    # Number of scale columns, i.e. the number of blocks along the N dimension.
    n = tl.cdiv(N, BLOCK_SIZE)
    # Row/column offsets of this program's BLOCK_SIZE x BLOCK_SIZE tile.
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Load the quantized tile, scale it by its single per-block factor, and store.
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)
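

# --- Reference sketch (not part of the upstream file) ---
# A minimal eager-mode PyTorch equivalent of the blockwise dequantization above,
# assuming `s` holds one scaling factor per BLOCK_SIZE x BLOCK_SIZE tile of `x`.
# Handy for checking the Triton kernel on small inputs; the name
# `weight_dequant_ref` is illustrative and not part of the original API.
def weight_dequant_ref(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    M, N = x.shape
    y = torch.empty(M, N, dtype=torch.get_default_dtype(), device=x.device)
    for i in range(0, M, block_size):
        for j in range(0, N, block_size):
            # Each tile is multiplied by its single per-block scale.
            tile = x[i:i + block_size, j:j + block_size].to(torch.float32)
            y[i:i + block_size, j:j + block_size] = tile * s[i // block_size, j // block_size]
    return y

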
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor of shape (ceil(M / block_size), ceil(N / block_size)).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`, in the default dtype.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
    assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
    M, N = x.size()
    # Output buffer in the default dtype (typically float32 or bfloat16).
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    # One program per BLOCK_SIZE x BLOCK_SIZE tile of the weight matrix.
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y
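

# --- Usage sketch (illustrative, not part of the upstream file) ---
# Dequantizes a random FP8 weight matrix with per-128x128-block scales; assumes a
# CUDA device and a PyTorch build with `torch.float8_e4m3fn` support, as used by
# the DeepSeek-V3 inference code this kernel comes from.
if __name__ == "__main__":
    torch.manual_seed(0)
    M, N, block_size = 512, 1024, 128
    # Fake quantized weights and one scale factor per 128x128 block.
    x = torch.randn(M, N, device="cuda").to(torch.float8_e4m3fn)
    s = torch.rand(M // block_size, N // block_size, device="cuda", dtype=torch.float32)
    y = weight_dequant(x, s, block_size=block_size)
    print(y.shape, y.dtype)
    # Sanity check against the eager reference sketch defined above.
    assert torch.allclose(y, weight_dequant_ref(x, s, block_size=block_size))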