diff --git a/llmc/compression/quantization/module_utils.py b/llmc/compression/quantization/module_utils.py
index 0fb9c5d5f..7fab3798d 100644
--- a/llmc/compression/quantization/module_utils.py
+++ b/llmc/compression/quantization/module_utils.py
@@ -416,6 +416,143 @@ def forward(
         return attn_output, attn_weights, past_key_value


+class LlmcDeepSeekV2MoEGate(nn.Module):
+    def __init__(self, module):
+        super().__init__()
+        self.config = module.config
+        self.top_k = module.config.num_experts_per_tok
+        self.n_routed_experts = module.config.n_routed_experts
+        self.routed_scaling_factor = module.config.routed_scaling_factor
+        self.scoring_func = module.config.scoring_func
+        self.alpha = module.config.aux_loss_alpha
+        self.seq_aux = module.config.seq_aux
+        self.topk_method = module.config.topk_method
+        self.n_group = module.config.n_group
+        self.topk_group = module.config.topk_group
+
+        # topk selection algorithm
+        self.norm_topk_prob = module.config.norm_topk_prob
+        self.gating_dim = module.config.hidden_size
+        self.fc = getattr(module, 'fc',
+                          nn.Linear(self.gating_dim, self.n_routed_experts, bias=False))
+        if not hasattr(module, 'fc'):
+            self.fc.weight = module.weight
+
+    @property
+    def weight(self):
+        return self.fc.weight
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        state_dict = super().state_dict(destination=destination,
+                                        prefix=prefix,
+                                        keep_vars=keep_vars)
+        if f'{prefix}fc.weight' in state_dict:
+            state_dict[f'{prefix}weight'] = state_dict.pop(f'{prefix}fc.weight')
+        return state_dict
+
+    def _fp32_forward(self, hidden_states):
+        if isinstance(self.fc, tuple(_LLMC_LINEAR_TYPES_)):
+            logits = self.fc(hidden_states.type(torch.float32), dtype=torch.float32)
+        else:
+            org_dtype = self.fc.weight.dtype
+            self.fc.weight.data = self.fc.weight.data.to(torch.float32)
+            logits = self.fc(hidden_states.type(torch.float32))
+            self.fc.weight.data = self.fc.weight.data.to(org_dtype)
+        return logits
+
+    def forward(self, hidden_states):
+        bsz, seq_len, h = hidden_states.shape
+        # compute gating score
+        hidden_states = hidden_states.view(-1, h)
+
+        logits = self._fp32_forward(hidden_states)
+
+        if self.scoring_func == 'softmax':
+            scores = logits.softmax(dim=-1, dtype=torch.float32)
+        else:
+            raise NotImplementedError(
+                f'Unsupported scoring function for MoE gating: {self.scoring_func}'
+            )
+
+        # select top-k experts
+        if self.topk_method == 'greedy':
+            topk_weight, topk_idx = torch.topk(
+                scores, k=self.top_k, dim=-1, sorted=False
+            )
+        elif self.topk_method == 'group_limited_greedy':
+            group_scores = (
+                scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values
+            )  # [n, n_group]
+            group_idx = torch.topk(
+                group_scores, k=self.topk_group, dim=-1, sorted=False
+            )[
+                1
+            ]  # [n, top_k_group]
+            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+            score_mask = (
+                group_mask.unsqueeze(-1)
+                .expand(
+                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
+                )
+                .reshape(bsz * seq_len, -1)
+            )  # [n, e]
+            tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+            topk_weight, topk_idx = torch.topk(
+                tmp_scores, k=self.top_k, dim=-1, sorted=False
+            )
+
+        # norm gate to sum 1
+        if self.top_k > 1 and self.norm_topk_prob:
+            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weight = topk_weight / denominator
+        else:
+            topk_weight = topk_weight * self.routed_scaling_factor
+        # expert-level computation auxiliary loss
+        if self.training and self.alpha > 0.0:
+            scores_for_aux = scores
+            aux_topk = self.top_k
+            # always compute aux loss based on the naive greedy topk method
+            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+            if self.seq_aux:
+                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+                ce = torch.zeros(
+                    bsz, self.n_routed_experts, device=hidden_states.device
+                )
+                ce.scatter_add_(
+                    1,
+                    topk_idx_for_aux_loss,
+                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
+                ).div_(seq_len * aux_topk / self.n_routed_experts)
+                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
+                    dim=1
+                ).mean() * self.alpha
+            else:
+                mask_ce = F.one_hot(
+                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
+                )
+                ce = mask_ce.float().mean(0)
+                Pi = scores_for_aux.mean(0)
+                fi = ce * self.n_routed_experts
+                aux_loss = (Pi * fi).sum() * self.alpha
+        else:
+            aux_loss = None
+
+        return topk_idx, topk_weight, aux_loss
+
+    @classmethod
+    @torch.no_grad()
+    def new(cls, module):
+        new_module = cls(module)
+        return new_module
+
+    def __repr__(self):
+        return (
+            'LlmcDeepSeekV2MoEGate('
+            + f'fc={self.fc})'
+        )
+
+
 class LlmcActFn(nn.Module):
     def __init__(self, module, a_qdq) -> None:
         super().__init__()
@@ -848,8 +985,8 @@ def forward(self, x, dtype=None):

     def convert_dtype(self, dtype):
         self.tmp_weight.data = self.tmp_weight.data.to(dtype)
-        if self.bias is not None:
-            self.bias.data = self.bias.data.to(dtype)
+        if self.tmp_bias is not None:
+            self.tmp_bias.data = self.tmp_bias.data.to(dtype)

     @classmethod
     @torch.no_grad()
@@ -876,9 +1013,6 @@ def get_func_name(cls, any_callable):
             return any_callable.func.__name__
         return any_callable.__name__

-    def register_activation_parameters(self, named_parameters):
-        pass
-
     def __repr__(self):
         return (
             f'FakeQuantLinear(in_features={self.in_features},'
@@ -909,15 +1043,26 @@ def __init__(self, weight, bias, ori_module, a_qdq):
         self.buf_rotate = False

     @torch.no_grad()
-    def forward(self, x):
+    def forward(self, x, dtype=None):
         if hasattr(self, 'buf_rotate') and self.buf_rotate:
             x = self.rotater.rotate(x)

         if self.a_qdq is not None:
             x = self.a_qdq(x, self)
+
+        org_dtype = self.weight.data.dtype
+        if dtype is not None:
+            self.convert_dtype(dtype)
+
         x = torch.functional.F.linear(x, self.weight, self.bias)
+        self.convert_dtype(org_dtype)
         return x

+    def convert_dtype(self, dtype):
+        self.weight.data = self.weight.data.to(dtype)
+        if self.bias is not None:
+            self.bias.data = self.bias.data.to(dtype)
+
     @classmethod
     @torch.no_grad()
     def new(cls, module, w_qdq, a_qdq, debug_print={}):
@@ -957,137 +1102,6 @@ def __repr__(self):
         )


-class LlmcDeepSeekV2MoEGate(nn.Module):
-    def __init__(self, module):
-        super().__init__()
-        self.config = module.config
-        self.top_k = module.config.num_experts_per_tok
-        self.n_routed_experts = module.config.n_routed_experts
-        self.routed_scaling_factor = module.config.routed_scaling_factor
-        self.scoring_func = module.config.scoring_func
-        self.alpha = module.config.aux_loss_alpha
-        self.seq_aux = module.config.seq_aux
-        self.topk_method = module.config.topk_method
-        self.n_group = module.config.n_group
-        self.topk_group = module.config.topk_group
-
-        # topk selection algorithm
-        self.norm_topk_prob = module.config.norm_topk_prob
-        self.gating_dim = module.config.hidden_size
-        self.fc = getattr(module, 'fc',
-                          nn.Linear(self.gating_dim, self.n_routed_experts, bias=False))
-        if not hasattr(module, 'fc'):
-            self.fc.weight = module.weight
-
-    @property
-    def weight(self):
-        return self.fc.weight
-
-    def _fp32_forward(self, hidden_states):
-        if isinstance(self.fc, tuple(_LLMC_LINEAR_TYPES_)):
-            logits = self.fc(hidden_states.type(torch.float32), dtype=torch.float32)
-        else:
-            org_dtype = self.fc.weight.dtype
-            self.fc.weight.data = self.fc.weight.data.to(torch.float32)
-            logits = self.fc(hidden_states.type(torch.float32))
-            self.fc.weight.data = self.fc.weight.data.to(org_dtype)
-        return logits
-
-    def forward(self, hidden_states):
-        bsz, seq_len, h = hidden_states.shape
-        # compute gating score
-        hidden_states = hidden_states.view(-1, h)
-
-        logits = self._fp32_forward(hidden_states)
-
-        if self.scoring_func == 'softmax':
-            scores = logits.softmax(dim=-1, dtype=torch.float32)
-        else:
-            raise NotImplementedError(
-                f'insupportable scoring function for MoE gating: {self.scoring_func}'
-            )
-
-        # select top-k experts
-        if self.topk_method == 'greedy':
-            topk_weight, topk_idx = torch.topk(
-                scores, k=self.top_k, dim=-1, sorted=False
-            )
-        elif self.topk_method == 'group_limited_greedy':
-            group_scores = (
-                scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values
-            )  # [n, n_group]
-            group_idx = torch.topk(
-                group_scores, k=self.topk_group, dim=-1, sorted=False
-            )[
-                1
-            ]  # [n, top_k_group]
-            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-            score_mask = (
-                group_mask.unsqueeze(-1)
-                .expand(
-                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
-                )
-                .reshape(bsz * seq_len, -1)
-            )  # [n, e]
-            tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-            topk_weight, topk_idx = torch.topk(
-                tmp_scores, k=self.top_k, dim=-1, sorted=False
-            )
-
-        # norm gate to sum 1
-        if self.top_k > 1 and self.norm_topk_prob:
-            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
-            topk_weight = topk_weight / denominator
-        else:
-            topk_weight = topk_weight * self.routed_scaling_factor
-        # expert-level computation auxiliary loss
-        if self.training and self.alpha > 0.0:
-            scores_for_aux = scores
-            aux_topk = self.top_k
-            # always compute aux loss based on the naive greedy topk method
-            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
-            if self.seq_aux:
-                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
-                ce = torch.zeros(
-                    bsz, self.n_routed_experts, device=hidden_states.device
-                )
-                ce.scatter_add_(
-                    1,
-                    topk_idx_for_aux_loss,
-                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
-                ).div_(seq_len * aux_topk / self.n_routed_experts)
-                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
-                    dim=1
-                ).mean() * self.alpha
-            else:
-                mask_ce = F.one_hot(
-                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
-                )
-                ce = mask_ce.float().mean(0)
-                Pi = scores_for_aux.mean(0)
-                fi = ce * self.n_routed_experts
-                aux_loss = (Pi * fi).sum() * self.alpha
-        else:
-            aux_loss = None
-
-        return topk_idx, topk_weight, aux_loss
-
-    @classmethod
-    @torch.no_grad()
-    def new(cls, module):
-        new_module = cls(module)
-        new_module.zeros_shape = None
-        new_module.zeros_dtype = None
-        return new_module
-
-    def __repr__(self):
-        return (
-            'LlmcDeepSeekV2MoEGate('
-            + f'fc={self.fc})'
-        )
-
-
 class VllmRealQuantLinear(nn.Module):
     def __init__(self, weight, bias, scales, input_scale, need_pack):
         super().__init__()
diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py
index 1dd3779fb..f9d7b1f64 100644
--- a/llmc/compression/quantization/quant.py
+++ b/llmc/compression/quantization/quant.py
@@ -43,6 +43,11 @@ def get_tensor_range(self, tensor, args={}):
             raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}')

     def get_running_tensor_range(self, act_tensors, alpha, args):
+        assert len(act_tensors) > 0, (
+            'Calibration data is insufficient. Please provide more data to ensure '
+            'all MoE experts receive an adequate number of tokens.'
+        )
+
         runing_min_vals, runing_max_vals = [], []
         if isinstance(act_tensors[0], tuple):
             # Handle multiple inputs by stacking tensors.
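
Both `_fp32_forward` and the new `forward(x, dtype=None)`/`convert_dtype` pair in this patch implement the same round-trip: temporarily cast the weight to `torch.float32` for a numerically sensitive computation (here, MoE routing logits), run the op, then restore the original dtype so the surrounding half-precision model is unaffected. A minimal self-contained sketch of that pattern; the `Fp32Gate` class and its sizes are illustrative only, not part of llmc:

```python
import torch
import torch.nn as nn


class Fp32Gate(nn.Module):
    """Illustrative gate that computes routing logits in float32 on a half-precision model."""

    def __init__(self, hidden_size, n_experts):
        super().__init__()
        self.fc = nn.Linear(hidden_size, n_experts, bias=False)

    def forward(self, hidden_states):
        # Upcast the weight in place, mirroring _fp32_forward in the diff above.
        org_dtype = self.fc.weight.data.dtype
        self.fc.weight.data = self.fc.weight.data.to(torch.float32)
        logits = self.fc(hidden_states.to(torch.float32))
        # Restore the original dtype so the rest of the model sees no change.
        self.fc.weight.data = self.fc.weight.data.to(org_dtype)
        return logits.softmax(dim=-1, dtype=torch.float32)


gate = Fp32Gate(16, 4).to(torch.float16)
scores = gate(torch.randn(2, 16, dtype=torch.float16))
assert scores.dtype == torch.float32
assert gate.fc.weight.dtype == torch.float16  # weight dtype was restored
```

Casting `.data` in place avoids keeping a permanent fp32 copy of the weight around; the cost is one transient cast per call.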
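The `state_dict` override added to `LlmcDeepSeekV2MoEGate` addresses checkpoint compatibility: the upstream DeepSeek-V2 gate stores its parameter as `weight`, while the wrapper holds it inside an `nn.Linear` named `fc`, so exported keys would otherwise change from `weight` to `fc.weight`. A toy sketch of that key-remapping idea; `WrappedGate` is hypothetical:

```python
import torch.nn as nn


class WrappedGate(nn.Module):
    """Toy wrapper that stores the gate weight in an nn.Linear but exports it as 'weight'."""

    def __init__(self, hidden_size, n_experts):
        super().__init__()
        self.fc = nn.Linear(hidden_size, n_experts, bias=False)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # Rename 'fc.weight' back to 'weight' so the original checkpoint layout is preserved.
        sd = super().state_dict(destination=destination, prefix=prefix,
                                keep_vars=keep_vars)
        if f'{prefix}fc.weight' in sd:
            sd[f'{prefix}weight'] = sd.pop(f'{prefix}fc.weight')
        return sd


print(list(WrappedGate(16, 4).state_dict().keys()))  # ['weight'] rather than ['fc.weight']
```

This covers saving only; reads through the wrapper stay compatible via the `weight` property shown in the diff.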