
Commit 4aee929

[Builder] Add support for Olive quantized models (#1647)
- Support new `"olive"` quant type. Weight and zero-point packings are the same as GPTQ; there is no `g_idx`.
- Similar to the `k_quant` mixed-precision `int4_algo`, select MatMuls can be in 8 bits.
- Currently, we ensure that the `q_proj`, `k_proj`, and `v_proj` MatMuls use the same configuration (bits and group_size) so that they can be merged without issues.
- The modules are generalized to remove the requirement that all MatMuls in a layer must have the same bits and group_size.
- `quant_weight` and `dequant_weight` support a missing `g_idx` by using `repeat_interleave`; otherwise, we would have to create a trivial `g_idx` as the Quark model does.
- `pack_ort_format` supports 8-bit packing.
1 parent 3bac249 commit 4aee929
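The `repeat_interleave` handling for a missing `g_idx` mentioned in the commit message can be sketched as follows. This is a hedged illustration with simplified, unpacked shapes (the builder operates on torch tensors with bit-packed storage; `np.repeat` stands in here for `torch.repeat_interleave`), not the builder's actual `dequant_weight`:

```python
import numpy as np

def dequant_weight(qweight, scales, zeros, group_size, g_idx=None):
    """Group-wise dequantization sketch (simplified, unpacked shapes).

    qweight: (out_features, in_features) integer quantized values
    scales, zeros: (out_features, in_features // group_size)
    """
    if g_idx is None:
        # No g_idx: consecutive columns share a group, so the per-group
        # scales/zero-points can simply be expanded along the column axis.
        s = np.repeat(scales, group_size, axis=1)
        z = np.repeat(zeros, group_size, axis=1)
    else:
        # With a g_idx, column i belongs to group g_idx[i], so gather the
        # scale/zero-point per column instead (what a trivial g_idx buys you).
        s = scales[:, g_idx]
        z = zeros[:, g_idx]
    return (qweight - z) * s
```

For consecutive grouping, the `repeat` path and a trivial `g_idx` of `[0, 0, ..., 1, 1, ...]` produce identical results, which is why the commit can avoid materializing the trivial index.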

File tree

2 files changed: +193 −116 lines changed


src/python/py/models/builder.py

Lines changed: 8 additions & 0 deletions
@@ -901,6 +901,10 @@ def make_matmul_int4_qdq(self, matmul, matmul_name, root_input, **kwargs):
             # print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.")
             # print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.")
             return self.make_matmul_float(matmul, matmul_name, root_input, **kwargs)
+
+        if matmul.bits != 4:
+            # code below assumes 4 bits with hard-coded shapes (* 2)
+            raise NotImplementedError(f"{matmul.bits} bits precision is not currently supported in QDQ format.")
 
         dequantize_output = self.make_dequantize_linear(f"{matmul_name}/DequantizeLinear", matmul)
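The hard-coded `* 2` shapes that the new check guards against come from 4-bit packing: two 4-bit values share one byte, so packed arrays hold half as many elements as unpacked ones. A minimal sketch with hypothetical helper names (not the builder's code):

```python
import numpy as np

def pack_int4(values):
    # Two 4-bit values fit in one uint8, so the packed array has half as
    # many elements as the unpacked one -- the source of the "* 2" shape
    # factor that the QDQ path hard-codes for 4 bits.
    v = np.asarray(values, dtype=np.uint8)
    return (v[0::2] & 0x0F) | ((v[1::2] & 0x0F) << 4)

def unpack_int4(packed):
    # Reverse: split each byte back into its low and high nibbles.
    out = np.empty(packed.size * 2, dtype=np.uint8)
    out[0::2] = packed & 0x0F
    out[1::2] = (packed >> 4) & 0x0F
    return out
```

With 8-bit weights there is no nibble packing and the factor disappears, which is why the non-4-bit case must be rejected rather than silently mis-shaped.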

@@ -994,6 +998,10 @@ def make_packed_matmul_int4(self, q_matmul, k_matmul, v_matmul, basename, root_i
         # Create dummy PackedMatMul class
         class PackedMatMul:
             def __init__(self):
+                if q_matmul.bits != k_matmul.bits or q_matmul.bits != v_matmul.bits:
+                    raise ValueError("All MatMuls must have the same bits for packed MatMul.")
+                if q_matmul.group_size != k_matmul.group_size or q_matmul.group_size != v_matmul.group_size:
+                    raise ValueError("All MatMuls must have the same group size for packed MatMul.")
                 self.qweight = torch.cat([q_matmul.qweight, k_matmul.qweight, v_matmul.qweight], dim=0)
                 self.scales = torch.cat([q_matmul.scales, k_matmul.scales, v_matmul.scales], dim=0)
                 self.qzeros = torch.cat([q_matmul.qzeros, k_matmul.qzeros, v_matmul.qzeros], dim=0)
