src/python/py/models/builder.py (7 changes: 7 additions & 0 deletions)
@@ -901,6 +901,9 @@ def make_matmul_int4_qdq(self, matmul, matmul_name, root_input, **kwargs):
             # print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.")
             # print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.")
             return self.make_matmul_float(matmul, matmul_name, root_input, **kwargs)
 
+        if matmul.bits != 4:
+            raise NotImplementedError(f"{matmul.bits} bits precision is not currently supported in QDQ format.")
+
         dequantize_output = self.make_dequantize_linear(f"{matmul_name}/DequantizeLinear", matmul)

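This hunk fails fast when a quantized weight is not 4-bit, since only 4-bit weights are currently supported by the QDQ (DequantizeLinear/MatMul) path. Below is a minimal, self-contained sketch of the guard's behavior; `check_qdq_bits` and the `SimpleNamespace` stand-in for a quantized MatMul are illustrative only, not part of the builder.

```python
from types import SimpleNamespace

def check_qdq_bits(matmul):
    # Mirrors the new guard: anything other than 4-bit weights is rejected
    # before a DequantizeLinear/MatMul (QDQ) subgraph would be emitted.
    if matmul.bits != 4:
        raise NotImplementedError(f"{matmul.bits} bits precision is not currently supported in QDQ format.")

check_qdq_bits(SimpleNamespace(bits=4))      # passes silently
try:
    check_qdq_bits(SimpleNamespace(bits=8))  # rejected
except NotImplementedError as e:
    print(e)
```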
@@ -994,6 +997,10 @@ def make_packed_matmul_int4(self, q_matmul, k_matmul, v_matmul, basename, root_i
         # Create dummy PackedMatMul class
         class PackedMatMul:
             def __init__(self):
+                if q_matmul.bits != k_matmul.bits or q_matmul.bits != v_matmul.bits:
+                    raise ValueError("All MatMuls must have the same bits for packed MatMul.")
+                if q_matmul.group_size != k_matmul.group_size or q_matmul.group_size != v_matmul.group_size:
+                    raise ValueError("All MatMuls must have the same group size for packed MatMul.")
                 self.qweight = torch.cat([q_matmul.qweight, k_matmul.qweight, v_matmul.qweight], dim=0)
                 self.scales = torch.cat([q_matmul.scales, k_matmul.scales, v_matmul.scales], dim=0)
                 self.qzeros = torch.cat([q_matmul.qzeros, k_matmul.qzeros, v_matmul.qzeros], dim=0)
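This hunk checks that the Q, K, and V projections were quantized with the same bit width and group size before their qweight/scales/qzeros tensors are concatenated along dim=0 into one packed MatMul. The sketch below shows the check and the packing step in isolation; `pack_qkv`, `fake_matmul`, and the tensor shapes are assumptions for demonstration, not the builder's real data layout.

```python
import torch
from types import SimpleNamespace

def pack_qkv(q, k, v):
    # Reject mismatched quantization settings up front: row-wise concatenation
    # only lines up scales/qzeros correctly when bits and group size agree.
    if q.bits != k.bits or q.bits != v.bits:
        raise ValueError("All MatMuls must have the same bits for packed MatMul.")
    if q.group_size != k.group_size or q.group_size != v.group_size:
        raise ValueError("All MatMuls must have the same group size for packed MatMul.")
    return SimpleNamespace(
        qweight=torch.cat([q.qweight, k.qweight, v.qweight], dim=0),
        scales=torch.cat([q.scales, k.scales, v.scales], dim=0),
        qzeros=torch.cat([q.qzeros, k.qzeros, v.qzeros], dim=0),
    )

def fake_matmul(rows, bits=4, group_size=32):
    # Toy quantized-weight holder with arbitrary shapes, for illustration only.
    return SimpleNamespace(
        bits=bits,
        group_size=group_size,
        qweight=torch.zeros(rows, 16, dtype=torch.uint8),
        scales=torch.zeros(rows, 4),
        qzeros=torch.zeros(rows, 2, dtype=torch.uint8),
    )

packed = pack_qkv(fake_matmul(8), fake_matmul(8), fake_matmul(8))
print(packed.qweight.shape)  # torch.Size([24, 16])

try:
    pack_qkv(fake_matmul(8), fake_matmul(8, bits=8), fake_matmul(8))
except ValueError as e:
    print(e)  # mismatched bits are rejected before packing
```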