288 changes: 151 additions & 137 deletions llmc/compression/quantization/module_utils.py
@@ -416,6 +416,143 @@ def forward(
return attn_output, attn_weights, past_key_value


class LlmcDeepSeekV2MoEGate(nn.Module):
def __init__(self, module):
super().__init__()
self.config = module.config
self.top_k = module.config.num_experts_per_tok
self.n_routed_experts = module.config.n_routed_experts
self.routed_scaling_factor = module.config.routed_scaling_factor
self.scoring_func = module.config.scoring_func
self.alpha = module.config.aux_loss_alpha
self.seq_aux = module.config.seq_aux
self.topk_method = module.config.topk_method
self.n_group = module.config.n_group
self.topk_group = module.config.topk_group

# topk selection algorithm
self.norm_topk_prob = module.config.norm_topk_prob
self.gating_dim = module.config.hidden_size
self.fc = getattr(module, 'fc',
nn.Linear(self.gating_dim, self.n_routed_experts, bias=False))
if not hasattr(module, 'fc'):
self.fc.weight = module.weight

@property
def weight(self):
return self.fc.weight

def state_dict(self, destination=None, prefix='', keep_vars=False):
state_dict = super().state_dict(destination=destination,
prefix=prefix,
keep_vars=keep_vars)
if f'{prefix}fc.weight' in state_dict:
state_dict[f'{prefix}weight'] = state_dict.pop(f'{prefix}fc.weight')
return state_dict

def _fp32_forward(self, hidden_states):
if isinstance(self.fc, tuple(_LLMC_LINEAR_TYPES_)):
logits = self.fc(hidden_states.type(torch.float32), dtype=torch.float32)
else:
org_dtype = self.fc.weight.dtype
self.fc.weight.data = self.fc.weight.data.to(torch.float32)
logits = self.fc(hidden_states.type(torch.float32))
self.fc.weight.data = self.fc.weight.data.to(org_dtype)
return logits

def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
# compute gating score
hidden_states = hidden_states.view(-1, h)

logits = self._fp32_forward(hidden_states)

if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1, dtype=torch.float32)
else:
raise NotImplementedError(
f'insupportable scoring function for MoE gating: {self.scoring_func}'
)

# select top-k experts
if self.topk_method == 'greedy':
topk_weight, topk_idx = torch.topk(
scores, k=self.top_k, dim=-1, sorted=False
)
elif self.topk_method == 'group_limited_greedy':
group_scores = (
scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(
group_scores, k=self.topk_group, dim=-1, sorted=False
)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(
bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
)
.reshape(bsz * seq_len, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weight, topk_idx = torch.topk(
tmp_scores, k=self.top_k, dim=-1, sorted=False
)

# norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
else:
topk_weight = topk_weight * self.routed_scaling_factor
# expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(
bsz, self.n_routed_experts, device=hidden_states.device
)
ce.scatter_add_(
1,
topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
dim=1
).mean() * self.alpha
else:
mask_ce = F.one_hot(
topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None

return topk_idx, topk_weight, aux_loss

@classmethod
@torch.no_grad()
def new(cls, module):
new_module = cls(module)
return new_module

def __repr__(self):
return (
'LlmcDeepSeekV2MoEGate('
+ f'fc={self.fc})'
)
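
As a quick illustration of the state_dict remapping in the relocated gate above, the sketch below wraps a toy stand-in for DeepseekV2's MoEGate and checks that the checkpoint key stays 'weight' rather than 'fc.weight'. The toy module and its config are hypothetical, not part of this PR; only the import path comes from the file being changed.

import torch
import torch.nn as nn
from llmc.compression.quantization.module_utils import LlmcDeepSeekV2MoEGate

class _ToyGate(nn.Module):
    # Hypothetical stand-in exposing the attributes the wrapper reads.
    def __init__(self, hidden_size=8, n_experts=4):
        super().__init__()
        self.config = type('Cfg', (), dict(
            num_experts_per_tok=2, n_routed_experts=n_experts,
            routed_scaling_factor=1.0, scoring_func='softmax',
            aux_loss_alpha=0.0, seq_aux=False, topk_method='greedy',
            n_group=1, topk_group=1, norm_topk_prob=True,
            hidden_size=hidden_size))()
        self.weight = nn.Parameter(torch.randn(n_experts, hidden_size))

gate = LlmcDeepSeekV2MoEGate.new(_ToyGate())
sd = gate.state_dict()
assert 'weight' in sd and 'fc.weight' not in sd  # keys stay compatible with the original gate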


class LlmcActFn(nn.Module):
def __init__(self, module, a_qdq) -> None:
super().__init__()
@@ -848,8 +985,8 @@ def forward(self, x, dtype=None):

def convert_dtype(self, dtype):
self.tmp_weight.data = self.tmp_weight.data.to(dtype)
if self.bias is not None:
self.bias.data = self.bias.data.to(dtype)
if self.tmp_bias is not None:
self.tmp_bias.data = self.tmp_bias.data.to(dtype)

@classmethod
@torch.no_grad()
@@ -876,9 +1013,6 @@ def get_func_name(cls, any_callable):
return any_callable.func.__name__
return any_callable.__name__

def register_activation_parameters(self, named_parameters):
pass

def __repr__(self):
return (
f'FakeQuantLinear(in_features={self.in_features},'
@@ -909,15 +1043,26 @@ def __init__(self, weight, bias, ori_module, a_qdq):
self.buf_rotate = False

@torch.no_grad()
def forward(self, x):
def forward(self, x, dtype=None):
if hasattr(self, 'buf_rotate') and self.buf_rotate:
x = self.rotater.rotate(x)

if self.a_qdq is not None:
x = self.a_qdq(x, self)

org_dtype = self.weight.data.dtype
if dtype is not None:
self.convert_dtype(dtype)

x = torch.functional.F.linear(x, self.weight, self.bias)
self.convert_dtype(org_dtype)
return x

def convert_dtype(self, dtype):
self.weight.data = self.weight.data.to(dtype)
if self.bias is not None:
self.bias.data = self.bias.data.to(dtype)

@classmethod
@torch.no_grad()
def new(cls, module, w_qdq, a_qdq, debug_print={}):
@@ -957,137 +1102,6 @@ def __repr__(self):
)


class LlmcDeepSeekV2MoEGate(nn.Module):
def __init__(self, module):
super().__init__()
self.config = module.config
self.top_k = module.config.num_experts_per_tok
self.n_routed_experts = module.config.n_routed_experts
self.routed_scaling_factor = module.config.routed_scaling_factor
self.scoring_func = module.config.scoring_func
self.alpha = module.config.aux_loss_alpha
self.seq_aux = module.config.seq_aux
self.topk_method = module.config.topk_method
self.n_group = module.config.n_group
self.topk_group = module.config.topk_group

# topk selection algorithm
self.norm_topk_prob = module.config.norm_topk_prob
self.gating_dim = module.config.hidden_size
self.fc = getattr(module, 'fc',
nn.Linear(self.gating_dim, self.n_routed_experts, bias=False))
if not hasattr(module, 'fc'):
self.fc.weight = module.weight

@property
def weight(self):
return self.fc.weight

def _fp32_forward(self, hidden_states):
if isinstance(self.fc, tuple(_LLMC_LINEAR_TYPES_)):
logits = self.fc(hidden_states.type(torch.float32), dtype=torch.float32)
else:
org_dtype = self.fc.weight.dtype
self.fc.weight.data = self.fc.weight.data.to(torch.float32)
logits = self.fc(hidden_states.type(torch.float32))
self.fc.weight.data = self.fc.weight.data.to(org_dtype)
return logits

def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
# compute gating score
hidden_states = hidden_states.view(-1, h)

logits = self._fp32_forward(hidden_states)

if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1, dtype=torch.float32)
else:
raise NotImplementedError(
f'insupportable scoring function for MoE gating: {self.scoring_func}'
)

# select top-k experts
if self.topk_method == 'greedy':
topk_weight, topk_idx = torch.topk(
scores, k=self.top_k, dim=-1, sorted=False
)
elif self.topk_method == 'group_limited_greedy':
group_scores = (
scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(
group_scores, k=self.topk_group, dim=-1, sorted=False
)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(
bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
)
.reshape(bsz * seq_len, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weight, topk_idx = torch.topk(
tmp_scores, k=self.top_k, dim=-1, sorted=False
)

# norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
else:
topk_weight = topk_weight * self.routed_scaling_factor
# expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(
bsz, self.n_routed_experts, device=hidden_states.device
)
ce.scatter_add_(
1,
topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
dim=1
).mean() * self.alpha
else:
mask_ce = F.one_hot(
topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None

return topk_idx, topk_weight, aux_loss

@classmethod
@torch.no_grad()
def new(cls, module):
new_module = cls(module)
new_module.zeros_shape = None
new_module.zeros_dtype = None
return new_module

def __repr__(self):
return (
'LlmcDeepSeekV2MoEGate('
+ f'fc={self.fc})'
)


class VllmRealQuantLinear(nn.Module):
def __init__(self, weight, bias, scales, input_scale, need_pack):
super().__init__()
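For the gate's _fp32_forward helper shown in this file, there are two paths: an LLMC fake-quant linear now accepts a dtype argument in forward and converts its own weights, while a plain nn.Linear is upcast manually around the matmul. Below is a minimal, self-contained sketch of that manual fp32 round-trip; the layer sizes and tensors are illustrative only.

import torch
import torch.nn as nn

fc = nn.Linear(16, 4, bias=False).half()            # gate projection stored in fp16
x = torch.randn(2, 16, dtype=torch.float16)

org_dtype = fc.weight.dtype
fc.weight.data = fc.weight.data.to(torch.float32)   # temporary upcast, as in _fp32_forward
logits = fc(x.type(torch.float32))                   # fp32 logits for a numerically stable softmax
fc.weight.data = fc.weight.data.to(org_dtype)        # restore the stored dtype

assert logits.dtype == torch.float32 and fc.weight.dtype == torch.float16
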
5 changes: 5 additions & 0 deletions llmc/compression/quantization/quant.py
@@ -43,6 +43,11 @@ def get_tensor_range(self, tensor, args={}):
raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}')

def get_running_tensor_range(self, act_tensors, alpha, args):
assert len(act_tensors) > 0, (
'Calibration data is insufficient. Please provide more data to ensure '
'all experts in the MOE receive an adequate number of tokens.'
)

runing_min_vals, runing_max_vals = [], []
if isinstance(act_tensors[0], tuple):
# Handle multiple inputs by stacking tensors.
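
The assert added to get_running_tensor_range guards against an empty act_tensors list, which can happen for rarely routed experts when the calibration set is small. A rough, self-contained illustration of the effect (expert count, top-k, and token count are made up):

import torch

n_routed_experts, top_k, n_tokens = 8, 2, 4          # deliberately tiny calibration set
scores = torch.rand(n_tokens, n_routed_experts).softmax(dim=-1)
topk_idx = scores.topk(top_k, dim=-1).indices         # greedy top-k routing

tokens_per_expert = torch.zeros(n_routed_experts, dtype=torch.long)
tokens_per_expert.scatter_add_(0, topk_idx.reshape(-1),
                               torch.ones_like(topk_idx.reshape(-1)))

# Several of the 8 experts receive zero tokens, so their act_tensors list stays
# empty and act_tensors[0] would raise; the assert reports the real cause instead.
print(tokens_per_expert)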