diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 65e01797f7..c06c439d10 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -901,6 +901,10 @@ def make_matmul_int4_qdq(self, matmul, matmul_name, root_input, **kwargs): # print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.") # print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.") return self.make_matmul_float(matmul, matmul_name, root_input, **kwargs) + + if matmul.bits != 4: + # code below assumes 4 bits with hard-coded shapes (* 2) + raise NotImplementedError(f"{matmul.bits} bits precision is not currently supported in QDQ format.") dequantize_output = self.make_dequantize_linear(f"{matmul_name}/DequantizeLinear", matmul) @@ -994,6 +998,10 @@ def make_packed_matmul_int4(self, q_matmul, k_matmul, v_matmul, basename, root_i # Create dummy PackedMatMul class class PackedMatMul: def __init__(self): + if q_matmul.bits != k_matmul.bits or q_matmul.bits != v_matmul.bits: + raise ValueError("All MatMuls must have the same bits for packed MatMul.") + if q_matmul.group_size != k_matmul.group_size or q_matmul.group_size != v_matmul.group_size: + raise ValueError("All MatMuls must have the same group size for packed MatMul.") self.qweight = torch.cat([q_matmul.qweight, k_matmul.qweight, v_matmul.qweight], dim=0) self.scales = torch.cat([q_matmul.scales, k_matmul.scales, v_matmul.scales], dim=0) self.qzeros = torch.cat([q_matmul.qzeros, k_matmul.qzeros, v_matmul.qzeros], dim=0) diff --git a/src/python/py/models/quantized_model.py b/src/python/py/models/quantized_model.py index 567e555919..535348b233 100644 --- a/src/python/py/models/quantized_model.py +++ b/src/python/py/models/quantized_model.py @@ -21,7 +21,7 @@ class QuantizedTensorModule: - def __init__(self, bits, group_size): + def __init__(self): self.qweight = None self.scales = None self.qzeros = None @@ -30,13 +30,17 @@ def
__init__(self, bits, group_size): self.in_features = 0 self.out_features = 0 - self.bits = bits - self._group_size = group_size + self.bits = None + self._group_size = None @property def group_size(self): return self._group_size if self._group_size != -1 else self.in_features + @group_size.setter + def group_size(self, value): + self._group_size = value + def __str__(self): qweight = f"qweight = {self.qweight.shape}, {self.qweight}\n" scales = f"scales = {self.scales.shape}, {self.scales}\n" @@ -57,35 +61,33 @@ def __init__(self): self.bias = None class QuantizedAttention: - def __init__(self, bits, group_size): - self.q_proj = QuantizedTensorModule(bits, group_size) - self.k_proj = QuantizedTensorModule(bits, group_size) - self.v_proj = QuantizedTensorModule(bits, group_size) - self.o_proj = QuantizedTensorModule(bits, group_size) + def __init__(self): + self.q_proj = QuantizedTensorModule() + self.k_proj = QuantizedTensorModule() + self.v_proj = QuantizedTensorModule() + self.o_proj = QuantizedTensorModule() self.rotary_emb = TensorModule() self.k_norm = TensorModule() self.q_norm = TensorModule() class QuantizedMLP: - def __init__(self, bits, group_size): - self.gate_proj = QuantizedTensorModule(bits, group_size) - self.up_proj = QuantizedTensorModule(bits, group_size) - self.down_proj = QuantizedTensorModule(bits, group_size) - self.fc1 = QuantizedTensorModule(bits, group_size) - self.fc2 = QuantizedTensorModule(bits, group_size) + def __init__(self): + self.gate_proj = QuantizedTensorModule() + self.up_proj = QuantizedTensorModule() + self.down_proj = QuantizedTensorModule() + self.fc1 = QuantizedTensorModule() + self.fc2 = QuantizedTensorModule() class QuantizedDecoderLayer: - def __init__(self, layer_id, bits, group_size): + def __init__(self, layer_id): self.layer_id = layer_id self.input_layernorm = TensorModule() - self.self_attn = QuantizedAttention(bits, group_size) + self.self_attn = QuantizedAttention() self.post_attention_layernorm = TensorModule() 
self.pre_feedforward_layernorm = TensorModule() self.post_feedforward_layernorm = TensorModule() - self.mlp = QuantizedMLP(bits, group_size) - self.bits = bits - self.group_size = group_size + self.mlp = QuantizedMLP() def is_empty(self): return self.input_layernorm.weight is None @@ -144,19 +146,18 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme # Chatglm3, e.g., transformer.encoder.layers.0.input_layernorm.weight name = name.replace("transformer.encoder", "model") layer_id = int(name.split(".")[2]) - module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id, local_bits, local_group_size)) - if local_bits != module.bits or local_group_size != module.group_size: - raise NotImplementedError("Setting different bits or group sizes for various linear modules within a decoder layer is not yet supported in the model builder.") + module = self.layers.setdefault(layer_id, QuantizedDecoderLayer(layer_id)) # Map weights and biases of norm, attention, and feed-forward network # Graph order is input_layernorm --> q_proj/k_proj/v_proj --> o_proj --> post_attention_layernorm --> gate_proj/up_proj --> down_proj # If model uses q_norm and k_norm, graph order is input_layernorm --> q_norm/q_proj/k_norm/k_proj/v_proj --> o_proj --> post_attention_layernorm --> gate_proj/up_proj --> down_proj + tensor_map = {} if bool(re.match(r"^model.layers\.\d+\.input_layernorm\.weight$", name)): # model.layers.layer_id.input_layernorm.weight - module.input_layernorm.weight = tensor + tensor_map["input_layernorm.weight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.input_layernorm\.bias$", name)): # model.layers.layer_id.input_layernorm.bias - module.input_layernorm.bias = tensor + tensor_map["input_layernorm.bias"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.rotary_emb\.inv_freq$", name)): # model.layers.layer_id.self_attn.rotary_emb.inv_freq # Skip rotary embedding weights since they can be re-calculated when looping through 
the model @@ -164,173 +165,173 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.q?weight$", name)): # model.layers.layer_id.self_attn.q_proj.weight # model.layers.layer_id.self_attn.q_proj.qweight - module.self_attn.q_proj.qweight = tensor + tensor_map["self_attn.q_proj.qweight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.(scales|weight_scale)$", name)): # model.layers.layer_id.self_attn.q_proj.scales # model.layers.layer_id.self_attn.q_proj.weight_scale - module.self_attn.q_proj.scales = tensor + tensor_map["self_attn.q_proj.scales"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.(qzeros|weight_zero_point)$", name)): # model.layers.layer_id.self_attn.q_proj.qzeros # model.layers.layer_id.self_attn.q_proj.weight_zero_point - module.self_attn.q_proj.qzeros = tensor + tensor_map["self_attn.q_proj.qzeros"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.g_idx$", name)): # model.layers.layer_id.self_attn.q_proj.g_idx - module.self_attn.q_proj.g_idx = tensor + tensor_map["self_attn.q_proj.g_idx"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.bias$", name)): # model.layers.layer_id.self_attn.q_proj.bias - module.self_attn.q_proj.bias = tensor + tensor_map["self_attn.q_proj.bias"] = tensor elif bool(re.match(r"^model\.layers\.\d+\.self_attn\.q_norm\.weight$", name)): # model.layers.layer_id.self_attn.q_norm.weight - module.self_attn.q_norm.weight = tensor + tensor_map["self_attn.q_norm.weight"] = tensor elif bool(re.match(r"^model\.layers\.\d+\.self_attn\.q_norm\.bias$", name)): # model.layers.layer_id.self_attn.q_norm.bias - module.self_attn.q_norm.bias = tensor + tensor_map["self_attn.q_norm.bias"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.q?weight$", name)): # model.layers.layer_id.self_attn.k_proj.qweight # model.layers.layer_id.self_attn.k_proj.weight - 
module.self_attn.k_proj.qweight = tensor + tensor_map["self_attn.k_proj.qweight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.(scales|weight_scale)$", name)): # model.layers.layer_id.self_attn.k_proj.scales # model.layers.layer_id.self_attn.k_proj.weight_scale - module.self_attn.k_proj.scales = tensor + tensor_map["self_attn.k_proj.scales"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.(qzeros|weight_zero_point)$", name)): # model.layers.layer_id.self_attn.k_proj.qzeros # model.layers.layer_id.self_attn.k_proj.weight_zero_point - module.self_attn.k_proj.qzeros = tensor + tensor_map["self_attn.k_proj.qzeros"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.g_idx$", name)): # model.layers.layer_id.self_attn.k_proj.g_idx - module.self_attn.k_proj.g_idx = tensor + tensor_map["self_attn.k_proj.g_idx"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.bias$", name)): # model.layers.layer_id.self_attn.k_proj.bias - module.self_attn.k_proj.bias = tensor + tensor_map["self_attn.k_proj.bias"] = tensor elif bool(re.match(r"^model\.layers\.\d+\.self_attn\.k_norm\.weight$", name)): # model.layers.layer_id.self_attn.k_norm.weight - module.self_attn.k_norm.weight = tensor + tensor_map["self_attn.k_norm.weight"] = tensor elif bool(re.match(r"^model\.layers\.\d+\.self_attn\.k_norm\.bias$", name)): # model.layers.layer_id.self_attn.k_norm.bias - module.self_attn.k_norm.bias = tensor + tensor_map["self_attn.k_norm.bias"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.q?weight$", name)): # model.layers.layer_id.self_attn.v_proj.qweight # model.layers.layer_id.self_attn.v_proj.weight - module.self_attn.v_proj.qweight = tensor + tensor_map["self_attn.v_proj.qweight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.(scales|weight_scale)$", name)): # model.layers.layer_id.self_attn.v_proj.scales # model.layers.layer_id.self_attn.v_proj.weight_scale - 
module.self_attn.v_proj.scales = tensor + tensor_map["self_attn.v_proj.scales"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.(qzeros|weight_zero_point)$", name)): # model.layers.layer_id.self_attn.v_proj.qzeros # model.layers.layer_id.self_attn.v_proj.weight_zero_point - module.self_attn.v_proj.qzeros = tensor + tensor_map["self_attn.v_proj.qzeros"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.g_idx$", name)): # model.layers.layer_id.self_attn.v_proj.g_idx - module.self_attn.v_proj.g_idx = tensor + tensor_map["self_attn.v_proj.g_idx"] = tensor elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.bias$", name)): # model.layers.layer_id.self_attn.v_proj.bias - module.self_attn.v_proj.bias = tensor + tensor_map["self_attn.v_proj.bias"] = tensor elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.q?weight$", name)): # model.layers.layer_id.self_attn.o_proj.qweight # model.layers.layer_id.self_attention.dense.qweight - module.self_attn.o_proj.qweight = tensor + tensor_map["self_attn.o_proj.qweight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.(scales|weight_scale)$", name)): # model.layers.layer_id.self_attn.o_proj.scales # model.layers.layer_id.self_attention.dense.scales # model.layers.layer_id.self_attn.o_proj.weight_scale # model.layers.layer_id.self_attention.dense.weight_scale - module.self_attn.o_proj.scales = tensor + tensor_map["self_attn.o_proj.scales"] = tensor elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.(qzeros|weight_zero_point)$", name)): # model.layers.layer_id.self_attn.o_proj.qzeros # model.layers.layer_id.self_attention.dense.qzeros # model.layers.layer_id.self_attn.o_proj.weight_zero_point # model.layers.layer_id.self_attention.dense.weight_zero_point - module.self_attn.o_proj.qzeros = tensor + tensor_map["self_attn.o_proj.qzeros"] = tensor elif 
bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.g_idx$", name)): # model.layers.layer_id.self_attn.o_proj.g_idx # model.layers.layer_id.self_attention.dense.g_idx - module.self_attn.o_proj.g_idx = tensor + tensor_map["self_attn.o_proj.g_idx"] = tensor elif bool(re.match(r"^model.layers\.\d+\.(self_attn.o_proj|self_attention.dense)\.bias$", name)): # model.layers.layer_id.self_attn.o_proj.bias # model.layers.layer_id.self_attention.dense.bias - module.self_attn.o_proj.bias = tensor + tensor_map["self_attn.o_proj.bias"] = tensor elif bool(re.match(r"^model.layers\.\d+\.post_attention_layernorm\.weight$", name)): # model.layers.layer_id.post_attention_layernorm.weight - module.post_attention_layernorm.weight = tensor + tensor_map["post_attention_layernorm.weight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.post_attention_layernorm\.bias$", name)): # model.layers.layer_id.post_attention_layernorm.bias - module.post_attention_layernorm.bias = tensor + tensor_map["post_attention_layernorm.bias"] = tensor elif bool(re.match(r"^model.layers\.\d+\.pre_feedforward_layernorm\.weight$", name)): # model.layers.layer_id.pre_feedforward_layernorm.weight - module.pre_feedforward_layernorm.weight = tensor + tensor_map["pre_feedforward_layernorm.weight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.pre_feedforward_layernorm\.bias$", name)): # model.layers.layer_id.pre_feedforward_layernorm.bias - module.pre_feedforward_layernorm.bias = tensor + tensor_map["pre_feedforward_layernorm.bias"] = tensor elif bool(re.match(r"^model.layers\.\d+\.post_feedforward_layernorm\.weight$", name)): # model.layers.layer_id.post_feedforward_layernorm.weight - module.post_feedforward_layernorm.weight = tensor + tensor_map["post_feedforward_layernorm.weight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.post_feedforward_layernorm\.bias$", name)): # model.layers.layer_id.post_feedforward_layernorm.bias - module.post_feedforward_layernorm.bias = tensor + 
tensor_map["post_feedforward_layernorm.bias"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.q?weight$", name)): # model.layers.layer_id.mlp.gate_proj.qweight # model.layers.layer_id.mlp.gate_proj.weight - module.mlp.gate_proj.qweight = tensor + tensor_map["mlp.gate_proj.qweight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.(scales|weight_scale)$", name)): # model.layers.layer_id.mlp.gate_proj.scales # model.layers.layer_id.mlp.gate_proj.weight_scale - module.mlp.gate_proj.scales = tensor + tensor_map["mlp.gate_proj.scales"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.(qzeros|weight_zero_point)$", name)): # model.layers.layer_id.mlp.gate_proj.qzeros # model.layers.layer_id.mlp.gate_proj.weight_zero_point - module.mlp.gate_proj.qzeros = tensor + tensor_map["mlp.gate_proj.qzeros"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.g_idx$", name)): # model.layers.layer_id.mlp.gate_proj.g_idx - module.mlp.gate_proj.g_idx = tensor + tensor_map["mlp.gate_proj.g_idx"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.gate_proj\.bias$", name)): # model.layers.layer_id.mlp.gate_proj.bias - module.mlp.gate_proj.bias = tensor + tensor_map["mlp.gate_proj.bias"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.q?weight$", name)): # model.layers.layer_id.mlp.up_proj.qweight # model.layers.layer_id.mlp.up_proj.weight - module.mlp.up_proj.qweight = tensor + tensor_map["mlp.up_proj.qweight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.(scales|weight_scale)$", name)): # model.layers.layer_id.mlp.up_proj.scales # model.layers.layer_id.mlp.up_proj.weight_scale - module.mlp.up_proj.scales = tensor + tensor_map["mlp.up_proj.scales"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.(qzeros|weight_zero_point)$", name)): # model.layers.layer_id.mlp.up_proj.qzeros # model.layers.layer_id.mlp.up_proj.weight_zero_point - module.mlp.up_proj.qzeros = tensor + 
tensor_map["mlp.up_proj.qzeros"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.g_idx$", name)): # model.layers.layer_id.mlp.up_proj.g_idx - module.mlp.up_proj.g_idx = tensor + tensor_map["mlp.up_proj.g_idx"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.up_proj\.bias$", name)): # model.layers.layer_id.mlp.up_proj.bias - module.mlp.up_proj.bias = tensor + tensor_map["mlp.up_proj.bias"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.q?weight$", name)): # model.layers.layer_id.mlp.down_proj.qweight # model.layers.layer_id.mlp.dense_4h_to_h.qweight # model.layers.layer_id.mlp.down_proj.weight # model.layers.layer_id.mlp.dense_4h_to_h.weight - module.mlp.down_proj.qweight = tensor + tensor_map["mlp.down_proj.qweight"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.(scales|weight_scale)$", name)): # model.layers.layer_id.mlp.down_proj.scales # model.layers.layer_id.mlp.dense_4h_to_h.scales # model.layers.layer_id.mlp.down_proj.weight_scale # model.layers.layer_id.mlp.dense_4h_to_h.weight_scale - module.mlp.down_proj.scales = tensor + tensor_map["mlp.down_proj.scales"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.(qzeros|weight_zero_point)$", name)): # model.layers.layer_id.mlp.down_proj.qzeros # model.layers.layer_id.mlp.dense_4h_to_h.qzeros # model.layers.layer_id.mlp.down_proj.weight_zero_point # model.layers.layer_id.mlp.dense_4h_to_h.weight_zero_point - module.mlp.down_proj.qzeros = tensor + tensor_map["mlp.down_proj.qzeros"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.g_idx$", name)): # model.layers.layer_id.mlp.down_proj.g_idx # model.layers.layer_id.mlp.dense_4h_to_h.g_idx - module.mlp.down_proj.g_idx = tensor + tensor_map["mlp.down_proj.g_idx"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.(down_proj|dense_4h_to_h)\.bias$", name)): # model.layers.layer_id.mlp.down_proj.bias # 
model.layers.layer_id.mlp.dense_4h_to_h.bias - module.mlp.down_proj.bias = tensor + tensor_map["mlp.down_proj.bias"] = tensor # Match against fused layers elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.q?weight$", name)): # model.layers.layer_id.self_attn.qkv_proj.qweight @@ -339,75 +340,86 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme # model.layers.layer_id.self_attention.query_key_value.weight q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "quark"} else q_size kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "quark"} else kv_size - module.self_attn.q_proj.qweight = tensor[:, : q_dim] - module.self_attn.k_proj.qweight = tensor[:, q_dim : q_dim + kv_dim] - module.self_attn.v_proj.qweight = tensor[:, q_dim + kv_dim :] + tensor_map["self_attn.q_proj.qweight"] = tensor[:, : q_dim] + tensor_map["self_attn.k_proj.qweight"] = tensor[:, q_dim : q_dim + kv_dim] + tensor_map["self_attn.v_proj.qweight"] = tensor[:, q_dim + kv_dim :] elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.(scales|weight_scale)$", name)): # model.layers.layer_id.self_attn.qkv_proj.scales # model.layers.layer_id.self_attention.query_key_value.scales # model.layers.layer_id.self_attn.qkv_proj.weight_scale # model.layers.layer_id.self_attention.query_key_value.weight_scale - module.self_attn.q_proj.scales = tensor[:, : q_size] - module.self_attn.k_proj.scales = tensor[:, q_size : q_size + kv_size] - module.self_attn.v_proj.scales = tensor[:, q_size + kv_size :] + tensor_map["self_attn.q_proj.scales"] = tensor[:, : q_size] + tensor_map["self_attn.k_proj.scales"] = tensor[:, q_size : q_size + kv_size] + tensor_map["self_attn.v_proj.scales"] = tensor[:, q_size + kv_size :] elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.(qzeros|weight_zero_point)$", name)): # 
model.layers.layer_id.self_attn.qkv_proj.qzeros # model.layers.layer_id.self_attention.query_key_value.qzeros # model.layers.layer_id.self_attn.qkv_proj.weight_zero_point # model.layers.layer_id.self_attention.query_key_value.weight_zero_point - q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "gptq", "quark"} else q_size - kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "gptq", "quark"} else kv_size - module.self_attn.q_proj.qzeros = tensor[:, : q_dim] - module.self_attn.k_proj.qzeros = tensor[:, q_dim : q_dim + kv_dim] - module.self_attn.v_proj.qzeros = tensor[:, q_dim + kv_dim :] + q_dim = q_size // (32 // local_bits) if quant_type in {"awq", "gptq", "olive", "quark"} else q_size + kv_dim = kv_size // (32 // local_bits) if quant_type in {"awq", "gptq", "olive", "quark"} else kv_size + tensor_map["self_attn.q_proj.qzeros"] = tensor[:, : q_dim] + tensor_map["self_attn.k_proj.qzeros"] = tensor[:, q_dim : q_dim + kv_dim] + tensor_map["self_attn.v_proj.qzeros"] = tensor[:, q_dim + kv_dim :] elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.g_idx$", name)): # model.layers.layer_id.self_attn.qkv_proj.g_idx # model.layers.layer_id.self_attention.query_key_value.g_idx - module.self_attn.q_proj.g_idx = tensor - module.self_attn.k_proj.g_idx = tensor - module.self_attn.v_proj.g_idx = tensor + tensor_map["self_attn.q_proj.g_idx"] = tensor + tensor_map["self_attn.k_proj.g_idx"] = tensor + tensor_map["self_attn.v_proj.g_idx"] = tensor elif bool(re.match(r"^model.layers\.\d+\.(self_attn.qkv_proj|self_attention.query_key_value)\.bias$", name)): # model.layers.layer_id.self_attn.qkv_proj.bias # model.layers.layer_id.self_attention.query_key_value.bias - module.self_attn.q_proj.bias = tensor[: q_size] - module.self_attn.k_proj.bias = tensor[q_size : q_size + kv_size] - module.self_attn.v_proj.bias = tensor[q_size + kv_size : ] + tensor_map["self_attn.q_proj.bias"] = tensor[: q_size] + 
tensor_map["self_attn.k_proj.bias"] = tensor[q_size : q_size + kv_size] + tensor_map["self_attn.v_proj.bias"] = tensor[q_size + kv_size : ] elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.q?weight$", name)): # model.layers.layer_id.mlp.gate_up_proj.qweight # model.layers.layer_id.mlp.dense_h_to_4h.qweight # model.layers.layer_id.mlp.gate_up_proj.weight # model.layers.layer_id.mlp.dense_h_to_4h.weight intermediate_dim = intermediate_size // (32 // local_bits) if quant_type in {"awq", "quark"} else intermediate_size - module.mlp.gate_proj.qweight = tensor[:, : intermediate_dim] - module.mlp.up_proj.qweight = tensor[:, intermediate_dim :] + tensor_map["mlp.gate_proj.qweight"] = tensor[:, : intermediate_dim] + tensor_map["mlp.up_proj.qweight"] = tensor[:, intermediate_dim :] elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.(scales|weight_scale)$", name)): # model.layers.layer_id.mlp.gate_up_proj.scales # model.layers.layer_id.mlp.dense_h_to_4h.scales # model.layers.layer_id.mlp.gate_up_proj.weight_scale # model.layers.layer_id.mlp.dense_h_to_4h.weight_scale - module.mlp.gate_proj.scales = tensor[:, : intermediate_size] - module.mlp.up_proj.scales = tensor[:, intermediate_size :] + tensor_map["mlp.gate_proj.scales"] = tensor[:, : intermediate_size] + tensor_map["mlp.up_proj.scales"] = tensor[:, intermediate_size :] elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h|gate_proj)\.(qzeros|weight_zero_point)$", name)): # model.layers.layer_id.mlp.gate_up_proj.qzeros # model.layers.layer_id.mlp.dense_h_to_4h.qzeros # model.layers.layer_id.mlp.gate_up_proj.weight_zero_point # model.layers.layer_id.mlp.dense_h_to_4h.weight_zero_point - intermediate_dim = intermediate_size // (32 // local_bits) if quant_type in {"awq", "gptq", "quark"} else intermediate_size - module.mlp.gate_proj.qzeros = tensor[:, : intermediate_dim] - module.mlp.up_proj.qzeros = tensor[:, intermediate_dim :] + 
intermediate_dim = intermediate_size // (32 // local_bits) if quant_type in {"awq", "gptq", "quark", "olive"} else intermediate_size + tensor_map["mlp.gate_proj.qzeros"] = tensor[:, : intermediate_dim] + tensor_map["mlp.up_proj.qzeros"] = tensor[:, intermediate_dim :] elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h)\.g_idx$", name)): # model.layers.layer_id.mlp.gate_up_proj.g_idx # model.layers.layer_id.mlp.dense_h_to_4h.g_idx - module.mlp.gate_proj.g_idx = tensor - module.mlp.up_proj.g_idx = tensor + tensor_map["mlp.gate_proj.g_idx"] = tensor + tensor_map["mlp.up_proj.g_idx"] = tensor elif bool(re.match(r"^model.layers\.\d+\.mlp.(gate_up_proj|dense_h_to_4h)\.bias$", name)): # model.layers.layer_id.mlp.gate_up_proj.bias # model.layers.layer_id.mlp.dense_h_to_4h.bias - module.mlp.gate_proj.bias = tensor[: intermediate_size] - module.mlp.up_proj.bias = tensor[intermediate_size: ] + tensor_map["mlp.gate_proj.bias"] = tensor[: intermediate_size] + tensor_map["mlp.up_proj.bias"] = tensor[intermediate_size: ] else: raise NotImplementedError(f"{name} in your quantized model is not recognized.") + for tensor_name, tensor_value in tensor_map.items(): + submodule = module + for sub_name in tensor_name.split(".")[:-1]: + submodule = getattr(submodule, sub_name) + if isinstance(submodule, QuantizedTensorModule): + for q_attr, q_value in [("bits", local_bits), ("_group_size", local_group_size)]: + if getattr(submodule, q_attr) is not None and getattr(submodule, q_attr) != q_value: + raise ValueError(f"Quantization {q_attr} mismatch for {name}: expected {getattr(submodule, q_attr)}, got {q_value}.") + setattr(submodule, q_attr, q_value) + setattr(submodule, tensor_name.split(".")[-1], tensor_value) + # Set LM head weights + biases if not already set if isinstance(self.lm_head, TensorModule) and self.lm_head.weight is None: # Embedding and LM head share same weights + biases (lm_head.weight == embedding.weight and lm_head.bias == embedding.bias) @@ -439,9 
+451,11 @@ def _initialize_quantized_lm_head(self, bits, group_size): Initialize `QuantizedTensorModule` for LM head if not already set """ if not isinstance(self.lm_head, QuantizedTensorModule): - q_lm_head = QuantizedTensorModule(bits, group_size) + q_lm_head = QuantizedTensorModule() q_lm_head.qweight = self.lm_head.weight q_lm_head.bias = self.lm_head.bias + q_lm_head.bits = bits + q_lm_head.group_size = group_size self.lm_head = q_lm_head def set_properties(self): @@ -457,6 +471,11 @@ def set_properties(self): elif self.quant_type == "gptq": self.lm_head.out_features = self.lm_head.qweight.shape[1] self.lm_head.in_features = self.lm_head.g_idx.shape[0] + elif self.quant_type == "olive": + self.lm_head.out_features = self.lm_head.qweight.shape[1] + # expects in_features to be divisible by the packing factor (32 // bits) + # not a new assumption since no code here accounts for padded packed weights + self.lm_head.in_features = self.lm_head.qweight.shape[0] * 32 // self.lm_head.bits else: raise NotImplementedError(f"The {self.quant_type} quantization method is not recognized.") for module in self.layers: @@ -503,6 +522,23 @@ def set_properties(self): module.mlp.down_proj.out_features = module.mlp.down_proj.qweight.shape[1] module.mlp.down_proj.in_features = module.mlp.down_proj.g_idx.shape[0] + elif self.quant_type == "olive": + # Set in_features and out_features + module.self_attn.q_proj.out_features = module.self_attn.q_proj.qweight.shape[1] + module.self_attn.q_proj.in_features = module.self_attn.q_proj.qweight.shape[0] * 32 // module.self_attn.q_proj.bits + module.self_attn.k_proj.out_features = module.self_attn.k_proj.qweight.shape[1] + module.self_attn.k_proj.in_features = module.self_attn.k_proj.qweight.shape[0] * 32 // module.self_attn.k_proj.bits + module.self_attn.v_proj.out_features = module.self_attn.v_proj.qweight.shape[1] + module.self_attn.v_proj.in_features = module.self_attn.v_proj.qweight.shape[0] * 32 // module.self_attn.v_proj.bits + 
module.self_attn.o_proj.out_features = module.self_attn.o_proj.qweight.shape[1] + module.self_attn.o_proj.in_features = module.self_attn.o_proj.qweight.shape[0] * 32 // module.self_attn.o_proj.bits + module.mlp.gate_proj.out_features = module.mlp.gate_proj.qweight.shape[1] + module.mlp.gate_proj.in_features = module.mlp.gate_proj.qweight.shape[0] * 32 // module.mlp.gate_proj.bits + module.mlp.up_proj.out_features = module.mlp.up_proj.qweight.shape[1] + module.mlp.up_proj.in_features = module.mlp.up_proj.qweight.shape[0] * 32 // module.mlp.up_proj.bits + module.mlp.down_proj.out_features = module.mlp.down_proj.qweight.shape[1] + module.mlp.down_proj.in_features = module.mlp.down_proj.qweight.shape[0] * 32 // module.mlp.down_proj.bits + else: raise NotImplementedError(f"The {self.quant_type} quantization method is not recognized.") @@ -566,33 +602,43 @@ def unpack_on_row_for_2_4_8_bits(self, tensor, bits, transpose): def unpack_on_row(self, tensor, bits, transpose): """ - Unpack tensor by row + Unpack tensor by row. Packed datatype is assumed to be int32. 
""" if bits in {2, 4, 8}: return self.unpack_on_row_for_2_4_8_bits(tensor, bits, transpose) else: raise NotImplementedError(f"Unpacking for {bits}-bit quantization is not currently supported.") - def pack_on_row_for_2_4_8_bits(self, tensor, bits, transpose): + def pack_on_row_for_2_4_8_bits(self, tensor, bits, transpose, packed_dtype=torch.int32): """ Perform general-purpose packing on 2-bit, 4-bit, or 8-bit tensor """ + packed_bitwidth = torch.iinfo(packed_dtype).bits + values_per_pack = packed_bitwidth // bits + orig_tensor = tensor.T if transpose else tensor + + original_cols = orig_tensor.shape[1] + pad_len = (values_per_pack - (original_cols % values_per_pack)) % values_per_pack + if pad_len > 0: + orig_tensor = torch.nn.functional.pad(orig_tensor, (0, pad_len), "constant", 0) + wf = torch.arange(0, bits).view(1, 1, -1) out = torch.bitwise_right_shift(orig_tensor.unsqueeze(-1), wf) out = torch.bitwise_and(out, 1) - out = out.reshape(orig_tensor.shape[0], -1, 32) - wf1 = torch.arange(0, 32, 1).view(1, 1, -1) + + out = out.reshape(orig_tensor.shape[0], -1, values_per_pack * bits) + wf1 = torch.arange(0, values_per_pack * bits, 1).view(1, 1, -1) out = torch.bitwise_left_shift(out, wf1) - out = out.sum(dim=-1).int() + out = out.sum(dim=-1).to(packed_dtype) return out.T if transpose else out - def pack_on_row(self, tensor, bits, transpose): + def pack_on_row(self, tensor, bits, transpose, packed_dtype=torch.int32): """ Pack tensor by row """ if bits in {2, 4, 8}: - return self.pack_on_row_for_2_4_8_bits(tensor, bits, transpose) + return self.pack_on_row_for_2_4_8_bits(tensor, bits, transpose, packed_dtype) else: raise NotImplementedError(f"Packing for {bits}-bit quantization is not currently supported.") @@ -608,9 +654,13 @@ def dequant_weight(self, module): # De-quantize weight to higher precision scale_zeros = zeros * scales - scale_mat = scales[g_idx] - scale_zeros_mat = scale_zeros[g_idx] - qdq_weight_T = intweight * scale_mat - scale_zeros_mat.half() + if 
g_idx is not None: + scales = scales[g_idx] + scale_zeros = scale_zeros[g_idx] + elif module.group_size != module.in_features: + scales = scales.repeat_interleave(module.group_size, 0) + scale_zeros = scale_zeros.repeat_interleave(module.group_size, 0) + qdq_weight_T = intweight * scales - scale_zeros.half() # Store unpacked result in `qweight` module.qweight = qdq_weight_T.T @@ -625,9 +675,13 @@ def quant_weight(self, module): g_idx = module.g_idx scale_zeros = zeros * scales - scale_mat = scales[g_idx] - scale_zeros_mat = scale_zeros[g_idx] - intweight_T = torch.round((weight + scale_zeros_mat) / scale_mat).to(torch.int) + if g_idx is not None: + scales = scales[g_idx] + scale_zeros = scale_zeros[g_idx] + elif module.group_size != module.in_features: + scales = scales.repeat_interleave(module.group_size, 0) + scale_zeros = scale_zeros.repeat_interleave(module.group_size, 0) + intweight_T = torch.round((weight + scale_zeros) / scales).to(torch.int) return intweight_T @@ -635,28 +689,28 @@ def pack_ort_format(self, module, intweight): """ Pack `scales`, `qzeros`, and `qweight` to ORT format """ - if module.bits != 4: + if module.bits not in [2, 4, 8]: raise NotImplementedError(f"{module.bits}-bit quantization in ORT is not currently supported by this tool.") intzeros_pt = module.qzeros.T if module.qzeros.dtype == module.scales.dtype else module.qzeros.T.byte() intweight_pt = intweight.byte() + kpack = 8 // module.bits block_size = module.group_size rows, cols = intweight_pt.shape - blob_size = block_size // 2 + blob_size = (block_size + kpack - 1) // kpack k_blocks = (rows + block_size - 1) // block_size padded_rows = k_blocks * block_size pad_len = padded_rows - rows if pad_len > 0: intweight_pt = torch.nn.functional.pad(intweight_pt, (0, 0, 0, pad_len), "constant", 0) - intzeros_pt = torch.nn.functional.pad(intzeros_pt, (0, intzeros_pt.shape[-1] & 1, 0, 0), "constant", 0) if module.qzeros.dtype != module.scales.dtype: - intzeros_pt = (intzeros_pt[:, 0::2]) | 
(intzeros_pt[:, 1::2] << 4) + intzeros_pt = self.pack_on_row(intzeros_pt, module.bits, transpose=False, packed_dtype=torch.uint8) intzeros_pt = intzeros_pt.reshape(-1) intweight_pt_T = intweight.T - intweight_pt_T = (intweight_pt_T[:, 0::2]) | (intweight_pt_T[:, 1::2] << 4) + intweight_pt_T = self.pack_on_row(intweight_pt_T, module.bits, transpose=False, packed_dtype=torch.uint8) intweight_pt_T = intweight_pt_T.reshape(cols, k_blocks, blob_size) scales_pt = module.scales.T.reshape(-1) @@ -922,6 +976,19 @@ def reverse_reorder_tensor(self, tensor, bits): int_tensor = tensor[:, reverse_order_tensor] return int_tensor +class OliveModel(GPTQModel): + def _load_quant_config(self, quant_attrs): + super()._load_quant_config(quant_attrs) + self.overrides = quant_attrs["config"].get("overrides") or {} + + def get_layer_bits(self, layer_name): + name = ".".join(layer_name.split(".")[:-1]) + return self.overrides.get(name, {}).get("bits", self.global_bits) + + def get_layer_group_size(self, layer_name): + name = ".".join(layer_name.split(".")[:-1]) + return self.overrides.get(name, {}).get("group_size", self.global_group_size) + class QuantModel: @staticmethod def from_pretrained(quant_type, **kwargs): @@ -934,6 +1001,8 @@ def from_pretrained(quant_type, **kwargs): model = AWQModel(quant_type, **kwargs) elif quant_type == "gptq": model = GPTQModel(quant_type, **kwargs) + elif quant_type == "olive": + model = OliveModel(quant_type, **kwargs) elif quant_type == "quark": model = QuarkModel(quant_type, **kwargs) else: