@@ -231,16 +231,27 @@ def transform_model(
             node = wc_params.node_with_weight
             weight = self.get_weight(node, wc_params.weight_port_id, model, graph)
             precomputed_compressed_weights = precomputed_compressed_weights or {}
+
+            dequantize_block_size = max(compression_config.group_size, 0)  # 0 - is no block wise quantization
+            dequantize_axis = (
+                get_weight_quantization_axis(node, wc_params.weight_port_id) if dequantize_block_size <= 0 else 0
+            )  # axis = 0 when blockwise
+
+            reduction_axes = wc_params.reduction_axes
+            if node.metatype == onnx_metatypes.ONNXGemmMetatype and opset_version < 21 and dequantize_block_size > 0:
+                attr_name = "transB" if wc_params.weight_port_id == 1 else "transA"
+                transpose = node.layer_attributes.node_attrs[attr_name]
+                weight = fns.transpose(weight) if transpose else weight
+                (axis,) = reduction_axes
+                axis = (axis + 1) % 2 if transpose else axis
+                reduction_axes = (axis,)
+
             compressed_weight = compress_weight(
                 Tensor(weight),
-                wc_params.reduction_axes,
+                reduction_axes,
                 compression_config,
                 precomputed_compressed_weights.get(wc_params.weight_name),
             )
-            dequantize_block_size = max(compression_config.group_size, 0)  # 0 - is no block wise quantization
-            dequantize_axis = (
-                get_weight_quantization_axis(node, wc_params.weight_port_id) if dequantize_block_size <= 0 else 0
-            )  # axis = 0 when blockwise
 
             # NOTE: The `DequantizeLinear` operation supports the `block_size` attribute only starting from opset 21.
             # For opsets earlier than 21, we use the `MatMulNBits` operation from ONNX Runtime contrib operators.
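The Gemm-specific branch added above exists because a Gemm weight can be stored transposed (transA/transB): the weight is brought into its untransposed (K, N) layout with fns.transpose and the single reduction axis is flipped via (axis + 1) % 2, so group-wise statistics are still computed along the same logical dimension. A standalone NumPy sketch of that bookkeeping (illustration only, not part of the PR; the toy shapes are made up):

    import numpy as np

    # Toy Gemm weight stored transposed (transB = 1): shape (N, K) = (8, 16).
    w_stored = np.arange(128, dtype=np.float32).reshape(8, 16)
    reduction_axes = (1,)  # reduce over K in the stored (N, K) layout

    # Mirror the PR's bookkeeping: undo the transpose and flip the reduction axis.
    transpose = True
    w = w_stored.T if transpose else w_stored  # now (K, N) = (16, 8)
    (axis,) = reduction_axes
    axis = (axis + 1) % 2 if transpose else axis  # 1 -> 0

    # Per-output-channel statistics are unchanged by the re-layout.
    assert np.allclose(w_stored.max(axis=1), w.max(axis=axis))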
@@ -428,16 +439,21 @@ def _replace_matmul_with_matmulnbits(
 
         original_matmul = self.name_to_node_map[weight_compression_parameters.node_with_weight.node_name]
 
-        activation_input_name = None
-        for input_name in original_matmul.input:
-            if input_name != weight_name:
-                activation_input_name = input_name
-        assert activation_input_name is not None, "Activation input name not found in original matmul node"
+        # Composing operation inputs: A, B, scales, zero_points[optional], g_idx[optional, deprecated], bias
+        bias_name = None
+        if weight_compression_parameters.node_with_weight.layer_attributes.has_bias():
+            bias_name = weight_compression_parameters.node_with_weight.layer_attributes.bias_attrs["name"]
+
+        activation_input_name = next(name for name in original_matmul.input if name not in [weight_name, bias_name])
 
         # Create MatMulNBits
         inputs = [activation_input_name, quantized_weight_name, scale_name]
         if zero_point is not None:
             inputs.append(zero_point_name)
+        if bias_name:
+            if zero_point is None:
+                inputs.append("")
+            inputs.append("")  # g_idx
+            inputs.append(bias_name)
 
         K, N = orig_weight.shape[0], orig_weight.shape[1]
         matmul_n_bits = helper.make_node(
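MatMulNBits takes its inputs positionally (A, B, scales, zero_points, g_idx, bias), so when a bias is present but a zero point is not, empty strings are appended as placeholders for the skipped optional inputs; an empty name is the standard ONNX way to mark an omitted optional input. Below is a rough, self-contained sketch of how such a node could be assembled with onnx.helper; the tensor names and the K/N/block_size values are invented for illustration, and the packing of the quantized weight tensor itself follows the ONNX Runtime contrib-op spec and is not shown:

    from onnx import helper

    # Hypothetical tensor names; in the PR these come from the created initializers.
    activation_name = "activation"
    quantized_weight_name = "w_quantized"
    scale_name = "w_scales"
    zero_point_name = None  # assume a symmetric scheme with no zero point
    bias_name = "w_bias"    # assume the original MatMul/Gemm carried a bias

    inputs = [activation_name, quantized_weight_name, scale_name]
    if zero_point_name is not None:
        inputs.append(zero_point_name)
    if bias_name:
        if zero_point_name is None:
            inputs.append("")  # placeholder for the omitted zero_points input
        inputs.append("")      # placeholder for g_idx (optional, deprecated)
        inputs.append(bias_name)

    # MatMulNBits is an ONNX Runtime contrib operator from the "com.microsoft" domain;
    # K/N describe the original (K, N) float weight, block_size the quantization group size.
    matmul_n_bits = helper.make_node(
        "MatMulNBits",
        inputs=inputs,
        outputs=["matmul_output"],
        name="MatMulNBits_0",
        domain="com.microsoft",
        K=4096,
        N=11008,
        bits=4,
        block_size=128,
    )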