[quant] deepgemm-fp8w8a8-b128 quantize #952
```diff
@@ -41,11 +41,15 @@ def __init__(self):
         self.act_scale_suffix = None  # no support for static input tensor scale for ds model.

     def quantize(self, weight: torch.Tensor):
-        raise Exception("Not implemented")
+        from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant
+        return weight_quant(weight, self.block_size)

     def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_custom_tensor_mananger=True):
-        qweight, weight_scale, input_scale = weights
+        if len(weights) == 3:
+            qweight, weight_scale, input_scale = weights
+        else:
+            qweight, weight_scale = weights
+            input_scale = None
         m, k = input_tensor.shape
         n = weights[0].shape[1]
         if input_scale is None:
```
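The `apply` hunk above is truncated right at the `input_scale is None` branch, where activation scales are presumably computed on the fly. For orientation only, here is a rough PyTorch sketch of per-token, per-128-column-group fp8 activation quantization, the activation layout that block-scaled fp8 GEMMs such as DeepGEMM typically consume. This is an illustration, not the kernel this PR actually calls, and the helper name is made up:

```python
import torch


def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
    """Reference-only sketch: one fp8 scale per (token, group_size-column) group."""
    m, k = x.shape
    assert x.is_contiguous() and k % group_size == 0
    xg = x.float().view(m, k // group_size, group_size)
    # Scale each group by its absolute maximum so values fit into fp8 e4m3's +/-448 range.
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
    scale = amax / 448.0
    xq = (xg / scale).to(torch.float8_e4m3fn).view(m, k)
    return xq, scale.squeeze(-1)  # scales: (m, k // group_size), float32
```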
`lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py` (new file, +60 lines)

```python
import torch
import triton
import triton.language as tl
from lightllm.utils.dist_utils import get_current_device_id


@triton.jit
def weight_quant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n_blocks = tl.cdiv(N, BLOCK_SIZE)

    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)

    # One scale per BLOCK_SIZE x BLOCK_SIZE tile, derived from the tile's absolute maximum.
    amax = tl.max(tl.abs(x))

    max_fp8e4m3_val = 448.0
    scale = amax / (max_fp8e4m3_val + 1e-6)

    y = (x / scale).to(y_ptr.dtype.element_ty)

    tl.store(y_ptr + offs, y, mask=mask)
    tl.store(s_ptr + pid_m * n_blocks + pid_n, scale)


def mm_weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.is_contiguous(), 'Input tensor must be contiguous'
    M, N = x.size()

    y_quant = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=x.device)

    num_blocks_m = triton.cdiv(M, block_size)
    num_blocks_n = triton.cdiv(N, block_size)
    s_scales = torch.empty((num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)

    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    weight_quant_kernel[grid](x, s_scales, y_quant, M, N, BLOCK_SIZE=block_size)
    return y_quant, s_scales


def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.is_contiguous(), 'Input tensor must be contiguous'
    x = x.cuda(get_current_device_id())
    if x.dim() == 3:
        # 3D weights (e.g. stacked expert weights) are quantized slice by slice.
        y_quant = torch.empty((x.shape[0], x.shape[1], x.shape[2]), dtype=torch.float8_e4m3fn, device=x.device)
        num_blocks_m = triton.cdiv(x.shape[1], block_size)
        num_blocks_n = triton.cdiv(x.shape[2], block_size)
        s_scales = torch.empty((x.shape[0], num_blocks_m, num_blocks_n), dtype=torch.float32, device=x.device)
        for i in range(x.shape[0]):
            y_quant[i], s_scales[i] = mm_weight_quant(x[i], block_size)
        return y_quant, s_scales
    else:
        y_quant, s_scales = mm_weight_quant(x, block_size)
        return y_quant.t(), s_scales.t()
```
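A quick usage sketch of the new `weight_quant` entry point. The weight shape and single-GPU setup are assumptions for illustration, not part of the PR; it also assumes a CUDA device with `float8_e4m3fn` support and that lightllm's device utilities resolve to that device:

```python
import torch
from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant

# Hypothetical (n, k) weight; contiguous bf16 input as the kernel expects.
w = torch.randn(4096, 11008, dtype=torch.bfloat16, device="cuda")

qweight, weight_scale = weight_quant(w, block_size=128)
print(qweight.shape, qweight.dtype)            # torch.Size([11008, 4096]) torch.float8_e4m3fn (transposed)
print(weight_scale.shape, weight_scale.dtype)  # torch.Size([86, 32]) torch.float32, one scale per 128x128 block
```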
```diff
@@ -5,6 +5,7 @@
 from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
 from lightllm.models.qwen3.model import Qwen3TpPartModel
 from lightllm.utils.log_utils import init_logger
+from lightllm.distributed.communication_op import dist_group_manager


 logger = init_logger(__name__)

@@ -21,3 +22,12 @@ class Qwen3MOEModel(Qwen3TpPartModel):
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
+
+    def _init_config(self):
+        super()._init_config()
+        # self.config["num_hidden_layers"] = 2
+        # self.config["n_layer"] = 2
+
+    def _init_custom(self):
+        super()._init_custom()
+        dist_group_manager.new_deepep_group(256, self.config["hidden_size"])
```
```diff
@@ -77,6 +77,18 @@ def init_vision_distributed_env(kvargs):
     del _a


+def run_once(func):
+    has_run = False
+
+    def wrapper(*args, **kwargs):
+        nonlocal has_run
+        if not has_run:
+            has_run = True
+            return func(*args, **kwargs)
+        else:
+            return None
+
+    return wrapper
+
+
+@run_once
 def init_distributed_env(kvargs):
     assert kvargs["world_size"] % kvargs["args"].nnodes == 0, "world_size should be divided by nnodes"
     node_world_size = kvargs["world_size"] // kvargs["args"].nnodes
```
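For reference, a self-contained sketch of how the `run_once` guard behaves; the decorator body is copied from the hunk above and `setup` is a made-up function:

```python
def run_once(func):
    # Same pattern as the decorator added above: remember whether the wrapped
    # function has already executed and skip any later calls.
    has_run = False

    def wrapper(*args, **kwargs):
        nonlocal has_run
        if not has_run:
            has_run = True
            return func(*args, **kwargs)
        return None

    return wrapper


@run_once
def setup():
    print("initializing")
    return 42


print(setup())  # prints "initializing", then 42
print(setup())  # body is skipped, prints None
```

Repeat calls fall through to `return None` rather than the first call's result, so only the first invocation's return value is observable to callers.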
Reviewer comment: Modifying `network_config` directly can cause unexpected side effects. Instead, retrieve the values and assign them to instance attributes; this prevents unintended modifications to the original dictionary.
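A minimal sketch of the pattern the reviewer suggests; the class name and config keys here are hypothetical, since the flagged code is not part of this excerpt:

```python
class SomeModel:
    def __init__(self, network_config: dict):
        # Copy the values you need onto the instance instead of mutating the shared dict.
        self.hidden_size = network_config["hidden_size"]
        self.num_hidden_layers = network_config.get("num_hidden_layers", 0)

        # Avoid this: it silently changes the dictionary for every other consumer.
        # network_config["num_hidden_layers"] = 2
```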