@@ -231,16 +231,27 @@ def transform_model(
  node = wc_params.node_with_weight
  weight = self.get_weight(node, wc_params.weight_port_id, model, graph)
  precomputed_compressed_weights = precomputed_compressed_weights or {}
+
+ dequantize_block_size = max(compression_config.group_size, 0)  # 0 - is no block wise quantization
+ dequantize_axis = (
+     get_weight_quantization_axis(node, wc_params.weight_port_id) if dequantize_block_size <= 0 else 0
+ )  # axis = 0 when blockwise
+
+ reduction_axes = wc_params.reduction_axes
+ if node.metatype == onnx_metatypes.ONNXGemmMetatype and opset_version < 21 and dequantize_block_size > 0:
+     attr_name = "transB" if wc_params.weight_port_id == 1 else "transA"
+     transpose = node.layer_attributes.node_attrs[attr_name]
+     weight = fns.transpose(weight) if transpose else weight
+     (axis,) = reduction_axes
+     axis = (axis + 1) % 2 if transpose else axis
+     reduction_axes = (axis,)
+
  compressed_weight = compress_weight(
      Tensor(weight),
-     wc_params.reduction_axes,
+     reduction_axes,
      compression_config,
      precomputed_compressed_weights.get(wc_params.weight_name),
  )
- dequantize_block_size = max(compression_config.group_size, 0)  # 0 - is no block wise quantization
- dequantize_axis = (
-     get_weight_quantization_axis(node, wc_params.weight_port_id) if dequantize_block_size <= 0 else 0
- )  # axis = 0 when blockwise

  # NOTE: The `DequantizeLinear` operation supports the `block_size` attribute only starting from opset 21.
  # For opsets earlier than 21, we use the `MatMulNBits` operation from ONNX Runtime contrib operators.
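
The axis arithmetic in the new Gemm branch above is a consequence of transposing a 2-D weight: transposing moves the dimension at axis r to axis (r + 1) % 2, so flipping the reduction axis keeps the compression statistics reduced over the same logical dimension. A minimal sketch of that fact, with illustrative shapes (not code from the PR):

import numpy as np

K, N = 8, 4
weight = np.random.rand(K, N).astype(np.float32)
axis = 0                                  # reduce over the K dimension

weight_t = weight.T                       # what fns.transpose(weight) does for a 2-D tensor
axis_t = (axis + 1) % 2                   # the K dimension now sits at axis 1

# Reducing over the flipped axis of the transposed weight gives the same result.
assert np.allclose(weight.min(axis=axis), weight_t.min(axis=axis_t))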
@@ -428,16 +439,21 @@ def _replace_matmul_with_matmulnbits(

  original_matmul = self.name_to_node_map[weight_compression_parameters.node_with_weight.node_name]

- activation_input_name = None
- for input_name in original_matmul.input:
-     if input_name != weight_name:
-         activation_input_name = input_name
- assert activation_input_name is not None, "Activation input name not found in original matmul node"
+ # Composing operation inputs: A, B, scales, zero_points[optional], g_idx[optional, deprecated], bias
  bias_name = None
  if weight_compression_parameters.node_with_weight.layer_attributes.has_bias():
      bias_name = weight_compression_parameters.node_with_weight.layer_attributes.bias_attrs["name"]

+ activation_input_name = next(name for name in original_matmul.input if name not in [weight_name, bias_name])
+
  # Create MatMulNBits
  inputs = [activation_input_name, quantized_weight_name, scale_name]
  if zero_point is not None:
      inputs.append(zero_point_name)
+ if bias_name:
+     if zero_point is None:
+         inputs.append("")
+     inputs.append("")  # g_idx
+     inputs.append(bias_name)

  K, N = orig_weight.shape[0], orig_weight.shape[1]
  matmul_n_bits = helper.make_node(
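
The composed inputs list above feeds a MatMulNBits node whose construction continues past this hunk. As a rough illustration only, such a node might be assembled as below; the attribute values and tensor names are assumptions based on the ONNX Runtime contrib-op schema, not taken from the PR:

from onnx import helper

# Inputs follow the order A, B (packed 4-bit weight), scales, zero_points, g_idx, bias;
# empty strings stand for omitted optional inputs.
inputs = ["activation", "weight_q4", "weight_scales", "", "", "bias"]

matmul_n_bits = helper.make_node(
    "MatMulNBits",
    inputs=inputs,
    outputs=["matmul_output"],
    name="MatMulNBits_0",
    domain="com.microsoft",  # ONNX Runtime contrib operator domain
    K=1280,                  # reduction (input feature) dimension
    N=10,                    # output feature dimension
    bits=4,                  # weight bit width
    block_size=64,           # quantization group size along K
)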
tests/onnx/common.py (47 additions, 0 deletions)
@@ -63,6 +63,53 @@ def add_matmul(
      )
      return output

+ def add_gemm(
+     self,
+     input: str,
+     shape: tuple[int],
+     output: Optional[str] = None,
+     weight_data: Optional[np.ndarray] = None,
+     bias_data: Optional[np.ndarray] = None,
+     trans_b: int = 0,
+ ) -> str:
+     i = len(self._nodes)
+
+     # weight params
+     w_name = f"W_{i}"
+     if weight_data is None:
+         w_values = np.random.rand(*shape).astype(np.float32)
+     else:
+         w_values = weight_data
+     w_initializer = onnx.helper.make_tensor(
+         name=w_name, data_type=onnx.TensorProto.FLOAT, dims=shape, vals=w_values.tobytes(), raw=True
+     )
+     self._initializers.append(w_initializer)
+
+     # bias params
+     b_name = f"B_{i}"
+     b_shape = shape[0] if trans_b else shape[1]
+     if bias_data is None:
+         b_values = np.random.rand(b_shape).astype(np.float32)
+     else:
+         b_values = bias_data
+     b_initializer = onnx.helper.make_tensor(
+         name=b_name, data_type=onnx.TensorProto.FLOAT, dims=(b_shape,), vals=b_values.tobytes(), raw=True
+     )
+     self._initializers.append(b_initializer)
+
+     output = f"Gemm_{i}_output" if output is None else output
+     self._nodes.append(
+         onnx.helper.make_node(
+             op_type="Gemm",
+             inputs=[input, w_name, b_name],
+             outputs=[output],
+             name=f"Gemm_{i}",
+             transA=0,
+             transB=trans_b,
+         )
+     )
+     return output
+
  def add_mul(self, input_a: str, input_b: str, output: Optional[str] = None) -> str:
      i = len(self._nodes)

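The bias length chosen in add_gemm (shape[0] when trans_b is set, shape[1] otherwise) follows from how ONNX Gemm applies transB: with transB=1 the weight is stored as (N, K) and transposed inside the operator, so the output feature count, and hence the bias length, is the first dimension. A small numpy sketch of that arithmetic, with illustrative shapes and alpha/beta left at their default of 1.0:

import numpy as np

x = np.random.rand(1, 1280).astype(np.float32)
w = np.random.rand(1280, 10).astype(np.float32)   # trans_b == 0: weight stored as (K, N)
b = np.random.rand(10).astype(np.float32)         # bias length == N == shape[1]

y_plain = x @ w + b                                # Gemm with transB=0

w_stored = w.T                                     # trans_b == 1: weight stored as (N, K)
y_trans = x @ w_stored.T + b                       # Gemm applies the transpose; bias length == shape[0]

assert np.allclose(y_plain, y_trans)
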
tests/onnx/quantization/test_weights_compression.py (41 additions, 0 deletions)
@@ -317,6 +317,47 @@ def test_matmulnbits():
  assert np.allclose(output21, output19, rtol=rtol, atol=1e-6)


+ @pytest.mark.xfail(
+     version.parse(onnx.__version__) >= version.parse("1.18.0"),
+     reason="onnxruntime not support default IR for onnx==1.18.0",
+ )
+ @pytest.mark.parametrize("trans_b", [0, 1])
+ def test_matmulnbits_gemm(trans_b: int):
+     # Build the model with a single Gemm operation
+     np.random.seed(42)
+
+     w = np.random.rand(1280, 10).astype(np.float32)
+     if trans_b:
+         w = w.T
+     b = np.random.rand(10).astype(np.float32)
+
+     mb = ModelBuilder()
+     x = mb.add_input("input", (1, 1280))
+     x = mb.add_gemm(x, shape=w.shape, weight_data=w, bias_data=b, trans_b=trans_b)
+
+     mb.add_output(x, (1, 10))
+
+     model_opset19 = mb.build(opset_version=19)
+     model_opset21 = mb.build(opset_version=21)
+
+     rtol = 1e-5
+     if version.parse(onnxruntime.__version__) < version.parse("1.21.1"):
+         rtol = 1e-3
+
+     compressed_model_opset21 = compress_weights(model_opset21, mode=CompressWeightsMode.INT4_SYM, group_size=64)
+     compressed_model_opset19 = compress_weights(model_opset19, mode=CompressWeightsMode.INT4_SYM, group_size=64)
+
+     dummy_input = np.random.rand(1, 1280).astype(np.float32)
+
+     sess21 = InferenceSession(compressed_model_opset21.SerializeToString())
+     sess19 = InferenceSession(compressed_model_opset19.SerializeToString())
+
+     output21 = sess21.run(None, {"input": dummy_input})[0]
+     output19 = sess19.run(None, {"input": dummy_input})[0]
+
+     assert np.allclose(output21, output19, rtol=rtol, atol=1e-6)
+
+
  class TestONNXTemplateWeightCompression(TemplateWeightCompression):
      @staticmethod
      def cast_to(x: np.ndarray, dtype: TensorDataType) -> np.ndarray:
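
A closing note on the numbers used in test_matmulnbits_gemm: the input feature dimension of 1280 divides evenly by the group_size of 64, so block-wise INT4 quantization applies cleanly along the reduction axis. A quick check of that arithmetic (the two-values-per-byte packing is an assumption about 4-bit storage, not something stated in the PR):

K, N, group_size = 1280, 10, 64
assert K % group_size == 0
groups_per_column = K // group_size    # 20 quantization groups along the K dimension
bytes_per_group = group_size // 2      # 32 bytes if two 4-bit values are packed per byte
print(groups_per_column, bytes_per_group)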