-
Notifications
You must be signed in to change notification settings - Fork 190
[OMNIML-2857] Support the DeepSeek V3.2 model #435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
19722c9
4e4bf16
90865c3
9b64663
ea9190e
2223d4b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,2 @@ | ||
| DeepSeek-V3/ | ||
| DeepSeek-V3.2-Exp/ |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,58 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import triton | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import triton.language as tl | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @triton.jit | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
meenchen marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Dequantizes weights using the provided scaling factors and stores the result. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x_ptr (tl.pointer): Pointer to the quantized weights. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s_ptr (tl.pointer): Pointer to the scaling factors. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M (int): Number of rows in the weight matrix. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| N (int): Number of columns in the weight matrix. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BLOCK_SIZE (tl.constexpr): Size of the block for tiling. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pid_m = tl.program_id(axis=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pid_n = tl.program_id(axis=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n = tl.cdiv(N, BLOCK_SIZE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offs = offs_m[:, None] * N + offs_n[None, :] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s = tl.load(s_ptr + pid_m * n + pid_n) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| y = x * s | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tl.store(y_ptr + offs, y, mask=mask) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Dequantizes the given weight tensor using the provided scale tensor. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x (torch.Tensor): The quantized weight tensor of shape (M, N). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_size (int, optional): The block size to use for dequantization. Defaults to 128. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.Tensor: The dequantized weight tensor of the same shape as `x`. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| M, N = x.size() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| y = torch.empty_like(x, dtype=torch.get_default_dtype()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return y | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+79
to
+95
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Block-scale tensor shape can go out-of-bounds The docstring and lack of shape checks let callers size Please validate the shape up front (using ceil-div) and update the docstring accordingly, e.g.: @@
- s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
+ s (torch.Tensor): The scale tensor of shape (ceil_div(M, block_size), ceil_div(N, block_size)).
@@
- M, N = x.size()
+ M, N = x.size()
+ m_blocks = (M + block_size - 1) // block_size
+ n_blocks = (N + block_size - 1) // block_size
+ assert s.size() == (m_blocks, n_blocks), \
+ f"Expected s.shape == ({m_blocks}, {n_blocks}), got {tuple(s.size())}"This keeps the kernel within bounds and matches its launch configuration. 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.