src/python/py/models/builder.py (114 changes: 108 additions & 6 deletions)
@@ -50,6 +50,8 @@ def __init__(
ir.DataType.BFLOAT16,
ir.DataType.INT4,
ir.DataType.UINT4,
ir.DataType.INT2,
ir.DataType.UINT2,
]
| int,
ep: str,
@@ -292,14 +294,23 @@ def __init__(

# Quantization-specific variables (INT2, INT4, INT8, etc.)
int4_algo_config = self.make_int4_algo_config(extra_options.get("int4_algo_config", "default"))
int2_algo_config = self.make_int2_algo_config(extra_options.get("int2_algo_config", "default"))
self.quant_attrs = {
"int2": {
"accuracy_level": int(extra_options.get("int2_accuracy_level", 0)), # Default is 0 for non-QDQ formats, default is 4 for QDQ formats
"block_size": int(extra_options.get("int2_block_size", 32)),
"is_symmetric": extra_options.get("int2_is_symmetric", True),
"op_types_to_quantize": extra_options.get("int2_op_types_to_quantize", ("MatMul", )),
"nodes_to_exclude": extra_options.get("int2_nodes_to_exclude", []),
"algo_config": int2_algo_config,
},
"int4": {
"accuracy_level": int(extra_options.get("int4_accuracy_level", 4 if self.ep in ["cpu", "webgpu"] else 0)),
"block_size": int(extra_options.get("int4_block_size", 32)),
"is_symmetric": extra_options.get("int4_is_symmetric", True),
"op_types_to_quantize": extra_options.get("int4_op_types_to_quantize", ("MatMul", )),
"nodes_to_exclude": extra_options.get("int4_nodes_to_exclude", []),
"algo_config": int4_algo_config,
"algo_config": int2_algo_config,
},
"use_qdq": extra_options.get("use_qdq", False),
}
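    # For illustration only: the "int2" entry above is driven by extra options named
    # like their int4 counterparts; the values below are hypothetical examples:
    #   extra_options = {
    #       "int2_accuracy_level": 4,
    #       "int2_block_size": 64,
    #       "int2_is_symmetric": False,
    #       "int2_op_types_to_quantize": ("MatMul",),
    #       "int2_nodes_to_exclude": [],
    #       "int2_algo_config": "default",
    #   }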
@@ -445,6 +456,9 @@ def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
print(f"Saving processing files in {out_dir} for GenAI")
tokenizer.save_pretrained(out_dir)

def make_int2_algo_config(self, quant_method: str):
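    # No algorithm-specific configs are wired up for 2-bit quantization yet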
return None

def make_int4_algo_config(self, quant_method: str):
int4_algo_config = None
if quant_method == "rtn":
@@ -475,6 +489,7 @@ def make_int4_algo_config(self, quant_method: str):
def to_int4(self) -> ir.Model:
quant = MatMulNBitsQuantizer(
model=ir.to_proto(self.model),
bits=4,
block_size=self.quant_attrs["int4"]["block_size"],
is_symmetric=self.quant_attrs["int4"]["is_symmetric"],
accuracy_level=self.quant_attrs["int4"]["accuracy_level"],
@@ -485,12 +500,29 @@ def to_int4(self) -> ir.Model:
)
quant.process()
return ir.from_proto(quant.model.model)

def to_int2(self) -> ir.Model:
    quant = MatMulNBitsQuantizer(
        model=ir.to_proto(self.model),
        bits=2,
        block_size=self.quant_attrs["int2"]["block_size"],
        is_symmetric=self.quant_attrs["int2"]["is_symmetric"],
        accuracy_level=self.quant_attrs["int2"]["accuracy_level"],
        nodes_to_exclude=self.quant_attrs["int2"]["nodes_to_exclude"],
        quant_format=QuantFormat.QDQ if self.quant_attrs["use_qdq"] else QuantFormat.QOperator,
        op_types_to_quantize=self.quant_attrs["int2"]["op_types_to_quantize"],
        algo_config=self.quant_attrs["int2"]["algo_config"],
    )
    quant.process()
    return ir.from_proto(quant.model.model)

def save_model(self, out_dir):
print(f"Saving ONNX model in {out_dir}")

already_quantized_in_qdq_format = self.quant_type is not None and self.quant_attrs["use_qdq"] # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
if self.onnx_dtype in {ir.DataType.INT2, ir.DataType.UINT2} and not already_quantized_in_qdq_format:
model = self.to_int2()
elif self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4} and not already_quantized_in_qdq_format:
model = self.to_int4()
else:
model = self.model
@@ -799,9 +831,14 @@ def make_matmul(self, matmul, basename, root_input, **kwargs):
def make_matmul_op(self, matmul, basename, root_input, **kwargs):
if self.onnx_dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16, ir.DataType.FLOAT}:
return self.make_matmul_float(matmul, basename, root_input, **kwargs)
elif self.onnx_dtype in {ir.DataType.INT2, ir.DataType.UINT2}:
if self.quant_attrs["use_qdq"]:
return self.make_matmul_nbits_qdq(matmul, basename, root_input, **kwargs)
else:
return self.make_matmul_int2(matmul, basename, root_input, **kwargs)
elif self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}:
if self.quant_attrs["use_qdq"]:
return self.make_matmul_nbits_qdq(matmul, basename, root_input, **kwargs)
else:
return self.make_matmul_int4(matmul, basename, root_input, **kwargs)
else:
@@ -818,6 +855,44 @@ def make_matmul_float(self, matmul, name, root_input, **kwargs):

return name

# TODO: Replace "make_matmul_int2" and "make_matmul_int4" with a single "make_matmul_nbits" helper that takes the bit width as an argument (a possible shape for that helper is sketched after make_matmul_int2 below).
def make_matmul_int2(self, matmul, basename, root_input, **kwargs):
if not hasattr(matmul, "qweight"):
# TODO: quantize weights, then save new MatMul weights for onnx model
# print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.")
# print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.")
return self.make_matmul_float(matmul, basename, root_input, **kwargs)

name = f"{basename}NBits"

# Input weights are quantized, save quantized MatMul numpy weights for onnx model
weight = name[1:].replace("/", ".") + ".qweight"
self.make_external_tensor(matmul.qweight, weight)
scales = name[1:].replace("/", ".") + ".scales"
self.make_external_tensor(matmul.scales, scales, to=self.io_dtype)

inputs = [root_input, weight, scales]

if hasattr(matmul, "qzeros") and matmul.qzeros is not None:
zeros = name[1:].replace("/", ".") + ".qzeros"
self.make_external_tensor(matmul.qzeros, zeros)
inputs.append(zeros)

if hasattr(matmul, "g_idx") and matmul.g_idx is not None:
g_idx = name[1:].replace("/", ".") + ".g_idx"
self.make_external_tensor(matmul.g_idx, g_idx, to=ir.DataType.INT32)
inputs.append(g_idx)

output = "logits" if kwargs.get("logits", False) else f"{name}/output_0"
self.make_node(
"MatMulNBits", inputs=inputs, outputs=[output], name=name, domain="com.microsoft",
accuracy_level=self.quant_attrs["int2"]["accuracy_level"],
bits=matmul.bits, block_size=matmul.group_size, K=matmul.in_features, N=matmul.out_features,
)
self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features])

return name
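
# A minimal sketch of the unified helper suggested by the TODO above. It is not
# part of this change; the signature and the f"int{bits}" lookup are assumptions,
# and the qzeros/g_idx handling from the per-bit-width methods is omitted for brevity.
def make_matmul_nbits(self, matmul, basename, root_input, bits, **kwargs):
    if not hasattr(matmul, "qweight"):
        # Weights are not pre-quantized; fall back to the float path
        return self.make_matmul_float(matmul, basename, root_input, **kwargs)

    name = f"{basename}NBits"
    weight = name[1:].replace("/", ".") + ".qweight"
    self.make_external_tensor(matmul.qweight, weight)
    scales = name[1:].replace("/", ".") + ".scales"
    self.make_external_tensor(matmul.scales, scales, to=self.io_dtype)
    inputs = [root_input, weight, scales]

    output = "logits" if kwargs.get("logits", False) else f"{name}/output_0"
    self.make_node(
        "MatMulNBits", inputs=inputs, outputs=[output], name=name, domain="com.microsoft",
        accuracy_level=self.quant_attrs[f"int{bits}"]["accuracy_level"],
        bits=bits, block_size=matmul.group_size, K=matmul.in_features, N=matmul.out_features,
    )
    self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', matmul.out_features])
    return name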

def make_matmul_int4(self, matmul, basename, root_input, **kwargs):
if not hasattr(matmul, "qweight"):
# TODO: quantize weights, then save new MatMul weights for onnx model
@@ -895,7 +970,7 @@ def make_dequantize_linear(self, dequantize_name, quantized_op):

return dequantize_output

def make_matmul_nbits_qdq(self, matmul, matmul_name, root_input, **kwargs):
if not hasattr(matmul, "qweight"):
# TODO: quantize weights, then save new MatMul weights for onnx model
# print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.")
@@ -962,6 +1037,8 @@ def make_matmul_lora(self, matmul, basename, root_input, **kwargs):
def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs):
if self.onnx_dtype in {ir.DataType.FLOAT, ir.DataType.FLOAT16, ir.DataType.BFLOAT16}:
return self.make_packed_matmul_float(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs)
elif self.onnx_dtype in {ir.DataType.INT2, ir.DataType.UINT2}:
return self.make_packed_matmul_int2(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs)
elif self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4}:
return self.make_packed_matmul_int4(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs)
else:
@@ -984,6 +1061,30 @@ def __init__(self):

return new_name

def make_packed_matmul_int2(self, q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs):
if not hasattr(q_matmul, "qweight"):
# TODO: quantize weights, then save new MatMul weights for onnx model
# print(f"Quantizing to {self.onnx_dtype} on-the-fly is not currently supported.")
# print(f"Saving as {self.io_dtype} on-the-fly and quantizing to {self.onnx_dtype} at the end.")
return self.make_packed_matmul_float(q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs)

# Create dummy PackedMatMul class
class PackedMatMul:
def __init__(self):
self.qweight = torch.cat([q_matmul.qweight, k_matmul.qweight, v_matmul.qweight], dim=0)
self.scales = torch.cat([q_matmul.scales, k_matmul.scales, v_matmul.scales], dim=0)
self.qzeros = torch.cat([q_matmul.qzeros, k_matmul.qzeros, v_matmul.qzeros], dim=0)
self.g_idx = q_matmul.g_idx

self.in_features = q_matmul.in_features
self.out_features = q_matmul.out_features + k_matmul.out_features + v_matmul.out_features
self.bits = q_matmul.bits
self.group_size = q_matmul.group_size
matmul = PackedMatMul()
new_name = self.make_matmul_int2(matmul, basename, root_input, **kwargs)

return new_name

def make_packed_matmul_int4(self, q_matmul, k_matmul, v_matmul, basename, root_input, **kwargs):
if not hasattr(q_matmul, "qweight"):
# TODO: quantize weights, then save new MatMul weights for onnx model
@@ -3702,7 +3803,7 @@ def parse_hf_token(hf_token):


def set_io_dtype(precision, execution_provider, extra_options) -> ir.DataType:
if precision in {"int8", "fp32"} or (precision in ("int2", "int4") and execution_provider == "cpu") or extra_options.get("use_webgpu_fp32", False):
# FP32 precision
return ir.DataType.FLOAT

@@ -3724,6 +3825,7 @@ def set_onnx_dtype(precision: str, extra_options: dict[str, Any]) -> ir.DataType:
"fp32": ir.DataType.FLOAT,
"fp16": ir.DataType.FLOAT16,
"bf16": ir.DataType.BFLOAT16,
"int2": ir.DataType.INT2,
}[precision]


@@ -3869,7 +3971,7 @@ def get_args():
"-p",
"--precision",
required=True,
choices=["int4", "bf16", "fp16", "fp32"],
choices=["int2","int4", "bf16", "fp16", "fp32"],
help="Precision of model",
)
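    # With "int2" added to the choices above, 2-bit quantization can be requested
    # the same way as 4-bit. The flags other than --precision below are assumptions
    # about the rest of this CLI, shown for illustration only:
    #   python builder.py -m <model_name_or_path> -o <output_dir> -p int2 -e cpu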
