
Commit 837576a

Merge pull request #209 from ModelTC/dev_fixbug
Fix gptq bug
2 parents: ee61b0f + e58bfb4

File tree

3 files changed: +36 −14 lines changed


llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 7 additions & 7 deletions
@@ -109,6 +109,7 @@ def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
                 'softmax_a_qdq': partial(self.a_qdq, aquantizer=self.aquantizer)
                 if self.quant_softmax else None
             }
+
         elif mode == 'quant_act_fn':
             params_dict = {
                 'a_qdq': partial(self.a_qdq, aquantizer=self.aquantizer)
@@ -324,7 +325,6 @@ def replace_attention(self, block, extra_modules):
     def replace_moe_gate(self, block):
         moe_gate_layer = self.model.get_moe_gate(block)
         if moe_gate_layer is not None:
-            logger.info(moe_gate_layer)
             moe_gate_module = _LLMC_MOE_GATE_MAP_[self.config['model']['type']]
             layers_dict = {'layers': moe_gate_layer}
             self.model.replace_module_subset(
@@ -333,7 +333,7 @@ def replace_moe_gate(self, block):
                 layers_dict,
                 self.block_idx,
                 self.get_replacement_params(
-                    mode='quant_moegate', w_only=self.w_only, name=None
+                    mode=None, w_only=self.w_only, name=None
                 ),
             )

@@ -554,13 +554,13 @@ def register_act_qparams(self, layers_dict, act_tensors):
             ):
                 scales = scales.cuda()
                 dist.all_reduce(scales, op=dist.ReduceOp.SUM)
-                scales = (scales / world_size).cpu()
+                scales = (scales / world_size)

             for name, layer in layers_dict.items():
                 layer.register_buffer(f'buf_act_scales_{i}', scales)
-                layer.register_buffer(f'buf_act_zeros_{i}', zeros)
-                layer.register_buffer(f'buf_act_qmin_{i}', qmin)
-                layer.register_buffer(f'buf_act_qmax_{i}', qmax)
+                layer.register_buffer(f'buf_act_zeros_{i}', zeros.cuda())
+                layer.register_buffer(f'buf_act_qmin_{i}', qmin.cuda())
+                layer.register_buffer(f'buf_act_qmax_{i}', qmax.cuda())

     @torch.no_grad()
     def apply_scale(self, scales, prev_op, layers):
@@ -808,7 +808,7 @@ def deploy(self, quant_format, keep_device=False):
             self.get_replacement_params(mode=quant_format, w_only=self.w_only),
             keep_device=keep_device
         )
-        self.set_non_linear_mode(quant_format, self.model.model, False)
+        self.set_non_linear_mode(quant_format, self.model.model, False)

         if self.model.vlm_model is not None:
             logger.info(f'Now, the vlm_model is: {self.model.vlm_model}')
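
For context on the register_act_qparams change above: the averaged scales now stay on the GPU after the all_reduce, and zeros/qmin/qmax are moved to CUDA before being registered, so every buf_act_* buffer ends up on the same device. Below is a minimal sketch of that pattern, assuming a CUDA device is available; the helper name and the stand-in nn.Linear with made-up qparams are illustrative, not the repository's API.

```python
import torch
import torch.distributed as dist
from torch import nn


def register_act_qparams_on_gpu(layer, i, scales, zeros, qmin, qmax):
    """Average activation scales across ranks (if running distributed) and
    register all activation qparams as CUDA buffers on the target layer."""
    scales = scales.cuda()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(scales, op=dist.ReduceOp.SUM)
        scales = scales / dist.get_world_size()  # no .cpu(): keep on GPU

    layer.register_buffer(f'buf_act_scales_{i}', scales)
    # Move the remaining qparams to the same device before registering.
    layer.register_buffer(f'buf_act_zeros_{i}', zeros.cuda())
    layer.register_buffer(f'buf_act_qmin_{i}', qmin.cuda())
    layer.register_buffer(f'buf_act_qmax_{i}', qmax.cuda())


# Hypothetical usage with made-up qparams:
layer = nn.Linear(16, 16)
register_act_qparams_on_gpu(
    layer, 0,
    scales=torch.ones(1), zeros=torch.zeros(1),
    qmin=torch.tensor(-128.0), qmax=torch.tensor(127.0),
)
```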

llmc/compression/quantization/gptq.py

Lines changed: 1 addition & 0 deletions
@@ -308,6 +308,7 @@ def collect_model_qparams(self):
         for i in range(len(self.blocks)):
             block = self.blocks[i]
             block = block.cuda()
+            self.replace_moe_gate(block)
             self.collect_block_qparams(block)
             block = block.cpu()
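
The gptq.py change only reorders work: each block's MoE gate is swapped for its LLMC wrapper before the block's quant params are collected, so GPTQ sees the wrapped gate rather than the original module. A rough sketch of that calling order with placeholder stand-ins (none of these classes or helpers are the repository's actual implementations):

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Placeholder transformer block with a toy MoE gate."""
    def __init__(self):
        super().__init__()
        self.gate = nn.Linear(8, 4, bias=False)
        self.proj = nn.Linear(8, 8)


def replace_moe_gate(block):
    # Stand-in for swapping the gate for its LLMC module.
    block.gate = nn.Sequential(block.gate)


def collect_block_qparams(block):
    # Stand-in: the real code records quant params from the block's modules,
    # which must already be in their replaced (LLMC) form.
    return [name for name, _ in block.named_modules()]


blocks = [ToyBlock() for _ in range(2)]
for block in blocks:
    if torch.cuda.is_available():
        block = block.cuda()
    replace_moe_gate(block)        # wrap the gate first ...
    collect_block_qparams(block)   # ... then collect qparams from the wrapped block
    block = block.cpu()
```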

llmc/compression/quantization/module_utils.py

Lines changed: 28 additions & 7 deletions
@@ -818,7 +818,7 @@ def __init__(self, weight, bias, ori_module, w_qdq, a_qdq):
         self.dynamic_quant_weight = False
         self.dynamic_quant_tmp_weight = False

-    def forward(self, x):
+    def forward(self, x, dtype=None):
         if hasattr(self, 'buf_rotate') and self.buf_rotate:
             x = self.rotater.rotate(x)

@@ -837,10 +837,20 @@ def forward(self, x):
         elif self.dynamic_quant_tmp_weight:
             self.tmp_weight = self.w_qdq(self)

+        org_dtype = self.tmp_weight.data.dtype
+        if dtype is not None:
+            self.convert_dtype(dtype)
+
         x = torch.functional.F.linear(x, self.tmp_weight, self.tmp_bias)

+        self.convert_dtype(org_dtype)
         return x

+    def convert_dtype(self, dtype):
+        self.tmp_weight.data = self.tmp_weight.data.to(dtype)
+        if self.bias is not None:
+            self.bias.data = self.bias.data.to(dtype)
+
     @classmethod
     @torch.no_grad()
     def new(cls, module, w_qdq, a_qdq):
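
The new dtype argument lets a caller run a single forward pass at a different precision: the weight (and bias) are cast before the matmul and restored to their original dtype afterwards. A standalone sketch of that cast-compute-restore pattern on a plain nn.Linear; the class and method names are illustrative, not the repository's:

```python
import torch
import torch.nn.functional as F
from torch import nn


class DtypeAwareLinear(nn.Module):
    """Linear layer that can temporarily run one forward pass in another dtype."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features, bias=bias)

    def _convert_dtype(self, dtype):
        # Cast parameters in place, mirroring the convert_dtype helper above.
        self.inner.weight.data = self.inner.weight.data.to(dtype)
        if self.inner.bias is not None:
            self.inner.bias.data = self.inner.bias.data.to(dtype)

    def forward(self, x, dtype=None):
        org_dtype = self.inner.weight.dtype
        if dtype is not None:
            self._convert_dtype(dtype)
        out = F.linear(x, self.inner.weight, self.inner.bias)
        self._convert_dtype(org_dtype)  # restore the original precision
        return out


# Example: compute logits in fp32 even though the layer itself is fp16.
layer = DtypeAwareLinear(8, 4).half()
x = torch.randn(2, 8)
logits = layer(x, dtype=torch.float32)
```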
@@ -964,21 +974,32 @@ def __init__(self, module):
         # topk selection algorithm
         self.norm_topk_prob = module.config.norm_topk_prob
         self.gating_dim = module.config.hidden_size
-        self.fc = nn.Linear(self.gating_dim, self.n_routed_experts, bias=False)
-        self.fc.weight = module.weight
+        self.fc = getattr(module, 'fc',
+                          nn.Linear(self.gating_dim, self.n_routed_experts, bias=False))
+        if not hasattr(module, 'fc'):
+            self.fc.weight = module.weight

     @property
     def weight(self):
         return self.fc.weight

+    def _fp32_forward(self, hidden_states):
+        if isinstance(self.fc, tuple(_LLMC_LINEAR_TYPES_)):
+            logits = self.fc(hidden_states.type(torch.float32), dtype=torch.float32)
+        else:
+            org_dtype = self.fc.weight.dtype
+            self.fc.weight.data = self.fc.weight.data.to(torch.float32)
+            logits = self.fc(hidden_states.type(torch.float32))
+            self.fc.weight.data = self.fc.weight.data.to(org_dtype)
+        return logits
+
     def forward(self, hidden_states):
         bsz, seq_len, h = hidden_states.shape
         # compute gating score
         hidden_states = hidden_states.view(-1, h)
-        org_dtype = self.fc.weight.dtype
-        self.fc.weight.data = self.fc.weight.data.to(torch.float32)
-        logits = self.fc(hidden_states.type(torch.float32))
-        self.fc.weight.data = self.fc.weight.data.to(org_dtype)
+
+        logits = self._fp32_forward(hidden_states)
+
         if self.scoring_func == 'softmax':
             scores = logits.softmax(dim=-1, dtype=torch.float32)
         else:
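
The new _fp32_forward helper computes the gating logits in float32 whichever form self.fc currently has: if the gate has already been replaced by an LLMC linear (which now accepts a dtype argument), the cast is delegated to it; otherwise the weight is cast around the call and restored, as before. A rough standalone sketch of the same dispatch; LLMC_LINEAR_TYPES here is a hypothetical stand-in for _LLMC_LINEAR_TYPES_:

```python
import torch
from torch import nn

# Hypothetical stand-in for the repository's _LLMC_LINEAR_TYPES_ registry.
LLMC_LINEAR_TYPES = ()


def fp32_gate_logits(fc: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    """Compute MoE gating logits in float32, whatever the gate layer's dtype."""
    if isinstance(fc, LLMC_LINEAR_TYPES):
        # LLMC linears accept a dtype argument and handle the cast themselves.
        return fc(hidden_states.to(torch.float32), dtype=torch.float32)
    # Plain nn.Linear: cast the weight around the call and restore it.
    org_dtype = fc.weight.dtype
    fc.weight.data = fc.weight.data.to(torch.float32)
    logits = fc(hidden_states.to(torch.float32))
    fc.weight.data = fc.weight.data.to(org_dtype)
    return logits


# Example with a plain fp16 gate:
gate = nn.Linear(8, 4, bias=False).half()
hidden = torch.randn(2, 8)
scores = fp32_gate_logits(gate, hidden).softmax(dim=-1, dtype=torch.float32)
```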
