src/python/py/models/builder.py (8 additions, 0 deletions)
@@ -901,6 +901,10 @@ def make_matmul_int4_qdq(self, matmul, matmul_name, root_input, **kwargs):
# print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.")
# print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.")
return self.make_matmul_float(matmul, matmul_name, root_input, **kwargs)

if matmul.bits != 4:
    # The code below assumes 4-bit quantization; shapes are hard-coded around packing two values per byte (the "* 2").
    raise NotImplementedError(f"{matmul.bits}-bit precision is not currently supported in QDQ format.")

dequantize_output = self.make_dequantize_linear(f"{matmul_name}/DequantizeLinear", matmul)
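Note (editor's sketch, not part of the diff): the QDQ path hard-codes unpacked shapes as the packed dimension times 2, which only holds for 4-bit weights, where exactly two values fit in each uint8 byte. The snippet below illustrates that packing arithmetic; unpack_int4 is a hypothetical helper written for this note, not a function in builder.py.

import numpy as np

def unpack_int4(packed: np.ndarray) -> np.ndarray:
    # Unpack a uint8 array that stores two 4-bit values per byte.
    low = packed & 0x0F           # lower nibble holds the even-indexed values
    high = (packed >> 4) & 0x0F   # upper nibble holds the odd-indexed values
    out = np.empty(packed.shape[:-1] + (packed.shape[-1] * 2,), dtype=np.uint8)
    out[..., 0::2] = low
    out[..., 1::2] = high
    return out

packed = np.array([[0x21, 0x43]], dtype=np.uint8)
print(unpack_int4(packed))  # [[1 2 3 4]] -- the last axis doubles, hence the "* 2"

With any other bit width the factor of 2 would be wrong, which is what the NotImplementedError above guards against.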

@@ -994,6 +998,10 @@ def make_packed_matmul_int4(self, q_matmul, k_matmul, v_matmul, basename, root_i
# Create dummy PackedMatMul class
class PackedMatMul:
def __init__(self):
if q_matmul.bits != k_matmul.bits or q_matmul.bits != v_matmul.bits:
    raise ValueError("All MatMuls must have the same bits for packed MatMul.")
if q_matmul.group_size != k_matmul.group_size or q_matmul.group_size != v_matmul.group_size:
    raise ValueError("All MatMuls must have the same group size for packed MatMul.")
self.qweight = torch.cat([q_matmul.qweight, k_matmul.qweight, v_matmul.qweight], dim=0)
self.scales = torch.cat([q_matmul.scales, k_matmul.scales, v_matmul.scales], dim=0)
self.qzeros = torch.cat([q_matmul.qzeros, k_matmul.qzeros, v_matmul.qzeros], dim=0)
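Note (editor's sketch, not part of the diff): the constructor concatenates qweight, scales, and qzeros along dim 0, and the widths of scales and qzeros are determined by bits and group_size, so the concatenation only lines up when all three MatMuls agree on both. The tensor shapes below are assumptions chosen for illustration, not values taken from this repository.

import torch

class FakeQuantMatMul:
    # Hypothetical stand-in with AWQ/GPTQ-style packed shapes (an assumption).
    def __init__(self, out_features, in_features, bits=4, group_size=32):
        self.bits, self.group_size = bits, group_size
        self.qweight = torch.zeros(out_features, in_features * bits // 8, dtype=torch.uint8)
        self.scales = torch.zeros(out_features, in_features // group_size)
        self.qzeros = torch.zeros(out_features, (in_features // group_size) * bits // 8, dtype=torch.uint8)

q, k, v = (FakeQuantMatMul(4096, 4096) for _ in range(3))
print(torch.cat([q.qweight, k.qweight, v.qweight], dim=0).shape)  # torch.Size([12288, 2048])
print(torch.cat([q.scales, k.scales, v.scales], dim=0).shape)     # torch.Size([12288, 128])

If k_matmul instead used group_size=64, its scales would be shaped [4096, 64] and the dim-0 concatenation would fail with an opaque shape-mismatch error; the explicit ValueError checks surface the real constraint up front.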