|
| 1 | +import os |
| 2 | +import torch |
| 3 | +import threading |
| 4 | +from typing import Optional, Tuple, List, Dict, Any |
| 5 | + |
| 6 | +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_tp import FusedMoeWeightTP |
| 7 | +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id |
| 8 | +from lightllm.common.quantization import Quantcfg |
| 9 | +from lightllm.utils.log_utils import init_logger |
| 10 | + |
| 11 | +logger = init_logger(__name__) |
| 12 | + |
| 13 | +FP4_VALUES = [ |
| 14 | + +0.0, |
| 15 | + +0.5, |
| 16 | + +1.0, |
| 17 | + +1.5, |
| 18 | + +2.0, |
| 19 | + +3.0, |
| 20 | + +4.0, |
| 21 | + +6.0, |
| 22 | + -0.0, |
| 23 | + -0.5, |
| 24 | + -1.0, |
| 25 | + -1.5, |
| 26 | + -2.0, |
| 27 | + -3.0, |
| 28 | + -4.0, |
| 29 | + -6.0, |
| 30 | +] |
| 31 | + |
| 32 | + |
| 33 | +class GPTOSSFusedMoeWeightTP(FusedMoeWeightTP): |
| 34 | + def __init__( |
| 35 | + self, |
| 36 | +        gate_up_proj_name: str,  # differs from FusedMoeWeightTP |
| 37 | + down_proj_name: str, |
| 38 | + e_score_correction_bias_name: str, |
| 39 | + weight_prefix: str, |
| 40 | + n_routed_experts: int, |
| 41 | + num_fused_shared_experts: int, |
| 42 | + split_inter_size: int, |
| 43 | + data_type: torch.dtype, |
| 44 | + network_config: Dict[str, Any], |
| 45 | + layer_num: int, |
| 46 | +        world_size: int = 1,  # differs from FusedMoeWeightTP |
| 47 | + quant_cfg: Quantcfg = None, |
| 48 | + ) -> None: |
| 49 | + super().__init__( |
| 50 | + gate_up_proj_name, |
| 51 | + down_proj_name, |
| 52 | + gate_up_proj_name, |
| 53 | + e_score_correction_bias_name, |
| 54 | + weight_prefix, |
| 55 | + n_routed_experts, |
| 56 | + num_fused_shared_experts, |
| 57 | + split_inter_size, |
| 58 | + data_type, |
| 59 | + network_config, |
| 60 | + layer_num, |
| 61 | + quant_cfg, |
| 62 | + ) |
| 63 | + self.hidden_size = network_config["hidden_size"] |
| 64 | + |
| 65 | + self.alpha = 1.702 |
| 66 | + self.limit = 7.0 |
| 67 | + self.tp_world_size_ = world_size |
| 68 | + |
| 69 | + self.w1_bias = None |
| 70 | + self.w2_bias = None |
| 71 | + |
| 72 | + self._down_bias_name = f"{weight_prefix}.{down_proj_name}_bias" |
| 73 | + self._down_blocks_name = f"{weight_prefix}.{down_proj_name}_blocks" |
| 74 | + self._down_scales_name = f"{weight_prefix}.{down_proj_name}_scales" |
| 75 | + self._gate_up_bias_name = f"{weight_prefix}.{gate_up_proj_name}_bias" |
| 76 | + self._gate_up_blocks_name = f"{weight_prefix}.{gate_up_proj_name}_blocks" |
| 77 | + self._gate_up_scales_name = f"{weight_prefix}.{gate_up_proj_name}_scales" |
| 78 | + return |
| 79 | + |
| 80 | + def _fuse_weight_scale(self): |
| 81 | +        raise NotImplementedError("Not implemented for GPT-OSS.") |
| 82 | + |
| 83 | + def _fuse(self): |
| 84 | +        raise NotImplementedError("Not implemented for GPT-OSS.") |
| 85 | + |
| 86 | + def load_hf_weights(self, weights): |
| 87 | + if ( |
| 88 | + weights.get(self._down_blocks_name, None) is not None |
| 89 | + and weights.get(self._down_scales_name, None) is not None |
| 90 | + ): |
| 91 | + w2 = self._convert_moe_packed_tensors( |
| 92 | + blocks=weights[self._down_blocks_name], |
| 93 | + scales=weights[self._down_scales_name], |
| 94 | + dtype=torch.bfloat16, |
| 95 | + )[:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :] |
| 96 | + self.w2 = (self._cuda(w2), None) |
| 97 | + |
| 98 | + if ( |
| 99 | + weights.get(self._gate_up_blocks_name, None) is not None |
| 100 | + and weights.get(self._gate_up_scales_name, None) is not None |
| 101 | + ): |
| 102 | + w1 = self._convert_moe_packed_tensors( |
| 103 | + blocks=weights[self._gate_up_blocks_name], |
| 104 | + scales=weights[self._gate_up_scales_name], |
| 105 | + dtype=torch.bfloat16, |
| 106 | + )[:, :, self.split_inter_size * self.tp_rank_ * 2 : self.split_inter_size * (self.tp_rank_ + 1) * 2] |
| 107 | + self.w1 = (self._cuda(w1), None) |
| 108 | + |
| 109 | + if weights.get(self._gate_up_bias_name, None) is not None: |
| 110 | + w1_bias = weights[self._gate_up_bias_name][ |
| 111 | + :, self.split_inter_size * self.tp_rank_ * 2 : self.split_inter_size * (self.tp_rank_ + 1) * 2 |
| 112 | + ] |
| 113 | + self.w1_bias = self._cuda(w1_bias) |
| 114 | + |
| 115 | + if weights.get(self._down_bias_name, None) is not None: |
| 116 | + w2_bias = weights[self._down_bias_name] |
| 117 | + self.w2_bias = self._cuda(w2_bias) |
| 118 | + |
| 119 | + def experts(self, hidden_states: torch.Tensor, routing_weights, layer_num): |
| 120 | + w1, w1_scale = self.w1 |
| 121 | + w2, w2_scale = self.w2 |
| 122 | + assert w1_scale is None and w2_scale is None, "For now, we do not support quantized weight in GPT-OSS." |
| 123 | + |
| 124 | + batch_size = hidden_states.shape[0] |
| 125 | + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) |
| 126 | + num_experts = routing_weights.shape[1] |
| 127 | + |
| 128 | + hidden_states = hidden_states.repeat(num_experts, 1) |
| 129 | + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) |
| 130 | + gate_up = torch.bmm(hidden_states, w1) + self.w1_bias[..., None, :] |
| 131 | + gate, up = gate_up[..., ::2], gate_up[..., 1::2] |
| 132 | + gate = gate.clamp(min=None, max=self.limit) |
| 133 | + up = up.clamp(min=-self.limit, max=self.limit) |
| 134 | + glu = gate * torch.sigmoid(gate * self.alpha) |
| 135 | + next_states = torch.bmm(((up + 1) * glu), w2) |
| 136 | + next_states = next_states + self.w2_bias[..., None, :] / self.tp_world_size_ |
| 137 | + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) |
| 138 | + next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] |
| 139 | + next_states = next_states.sum(dim=0) |
| 140 | + return next_states |
| 141 | + |
| 142 | + def _convert_moe_packed_tensors( |
| 143 | + self, |
| 144 | + blocks, |
| 145 | + scales, |
| 146 | + *, |
| 147 | + dtype: torch.dtype = torch.bfloat16, |
| 148 | + rows_per_chunk: int = 32768 * 1024, |
| 149 | + ) -> torch.Tensor: |
| 150 | + """ |
| 151 | +        Dequantize the packed MXFP4 weights (uint8 nibble pairs plus per-block exponent scales) |
| 152 | +        into `dtype` tensors laid out for the GPT-OSS forward pass. |
| 153 | + """ |
| 154 | + import math |
| 155 | + |
| 156 | + # Check if blocks and scales are on CPU, and move to GPU if so |
| 157 | + if not blocks.is_cuda and torch.cuda.is_available(): |
| 158 | + blocks = blocks.cuda() |
| 159 | + scales = scales.cuda() |
| 160 | + |
| 161 | +        scales = scales.to(torch.int32) - 127  # remove the E8M0 exponent bias (127) |
| 162 | + |
| 163 | + assert blocks.shape[:-1] == scales.shape, f"{blocks.shape[:-1]=} does not match {scales.shape=}" |
| 164 | + |
| 165 | + lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device) |
| 166 | + |
| 167 | + *prefix_shape, G, B = blocks.shape |
| 168 | + rows_total = math.prod(prefix_shape) * G |
| 169 | + |
| 170 | + blocks = blocks.reshape(rows_total, B) |
| 171 | + scales = scales.reshape(rows_total, 1) |
| 172 | + |
| 173 | + out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device) |
| 174 | + |
| 175 | + for r0 in range(0, rows_total, rows_per_chunk): |
| 176 | + r1 = min(r0 + rows_per_chunk, rows_total) |
| 177 | + |
| 178 | + blk = blocks[r0:r1] |
| 179 | + exp = scales[r0:r1] |
| 180 | + |
| 181 | + # nibble indices -> int64 |
| 182 | + idx_lo = (blk & 0x0F).to(torch.long) |
| 183 | + idx_hi = (blk >> 4).to(torch.long) |
| 184 | + |
| 185 | + sub = out[r0:r1] |
| 186 | + sub[:, 0::2] = lut[idx_lo] |
| 187 | + sub[:, 1::2] = lut[idx_hi] |
| 188 | + |
| 189 | + torch.ldexp(sub, exp, out=sub) |
| 190 | + del idx_lo, idx_hi, blk, exp, sub |
| 191 | + |
| 192 | + out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2) |
| 193 | + del blocks, scales, lut |
| 194 | + return out.transpose(1, 2).contiguous() |
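
For reference (not part of the commit), here is a minimal standalone sketch of the unpacking that `_convert_moe_packed_tensors` performs per row. The helper name `dequant_mxfp4_rows` and the toy inputs are invented: each uint8 byte stores two FP4 codes (low nibble first) that index into `FP4_VALUES`, and every block of bytes shares one exponent stored with a bias of 127, which `torch.ldexp` applies as a power of two.

```python
import torch

# Same 16-entry lookup table as FP4_VALUES above.
FP4_VALUES = [+0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]


def dequant_mxfp4_rows(blocks: torch.Tensor, scales: torch.Tensor,
                       dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """blocks: (rows, B) uint8, scales: (rows,) uint8 -> (rows, 2 * B) in `dtype`."""
    lut = torch.tensor(FP4_VALUES, dtype=dtype)
    exp = scales.to(torch.int32) - 127          # remove the exponent bias
    idx_lo = (blocks & 0x0F).to(torch.long)     # first FP4 code of each byte
    idx_hi = (blocks >> 4).to(torch.long)       # second FP4 code of each byte
    out = torch.empty(blocks.shape[0], blocks.shape[1] * 2, dtype=dtype)
    out[:, 0::2] = lut[idx_lo]
    out[:, 1::2] = lut[idx_hi]
    torch.ldexp(out, exp[:, None], out=out)     # value * 2 ** exp, in place
    return out


# Byte 0x21 packs codes 0x1 (+0.5) and 0x2 (+1.0); a stored scale of 128 means
# a shared exponent of 2 ** 1, so the dequantized pair is (1.0, 2.0).
blocks = torch.tensor([[0x21]], dtype=torch.uint8)
scales = torch.tensor([128], dtype=torch.uint8)
print(dequant_mxfp4_rows(blocks, scales))       # tensor([[1., 2.]], dtype=torch.bfloat16)
```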
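Similarly, a small illustration of the clamped, gated activation inside `experts` (the helper name `gptoss_glu` and the toy shapes are invented): `gate` is clamped from above at `limit`, `up` on both sides, and each expert computes `(up + 1) * gate * sigmoid(alpha * gate)` before its down projection; `gate * sigmoid(1.702 * gate)` is the usual sigmoid approximation of GELU.

```python
import torch

ALPHA, LIMIT = 1.702, 7.0                            # constants set in __init__


def gptoss_glu(gate_up: torch.Tensor) -> torch.Tensor:
    # gate and up projections are interleaved along the last dimension
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(max=LIMIT)                     # clamp gate from above only
    up = up.clamp(min=-LIMIT, max=LIMIT)             # clamp up on both sides
    glu = gate * torch.sigmoid(gate * ALPHA)         # sigmoid-gated nonlinearity
    return (up + 1) * glu


x = torch.randn(4, 2 * 8)        # (tokens, 2 * split_inter_size) toy activations
print(gptoss_glu(x).shape)       # torch.Size([4, 8])
```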