diff --git a/examples/onnx_ptq/torch_quant_to_onnx.py b/examples/onnx_ptq/torch_quant_to_onnx.py
index 418f1d7e..6246fe36 100644
--- a/examples/onnx_ptq/torch_quant_to_onnx.py
+++ b/examples/onnx_ptq/torch_quant_to_onnx.py
@@ -83,12 +83,12 @@ def forward_loop(model):
     return quantized_model


-def get_model_input_shape(model_name):
+def get_model_input_shape(model_name, batch_size):
     """Get the input shape from timm model configuration."""
     model = timm.create_model(model_name, pretrained=True, num_classes=1000)
     data_config = timm.data.resolve_model_data_config(model)
     input_size = data_config["input_size"]
-    return (1, *tuple(input_size))  # Add batch dimension
+    return (batch_size, *tuple(input_size))  # Add batch dimension


 def main():
@@ -119,11 +119,17 @@ def main():
         default=512,
         help="Number of images to use in calibration [1-512]",
     )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="Batch size for calibration and ONNX model export.",
+    )

     args = parser.parse_args()

     # Get input shape from model config
-    input_shape = get_model_input_shape(args.timm_model_name)
+    input_shape = get_model_input_shape(args.timm_model_name, args.batch_size)

     # Create model and move to appropriate device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff --git a/examples/windows/onnx_ptq/genai_llm/requirements.txt b/examples/windows/onnx_ptq/genai_llm/requirements.txt
index 10e3f9d7..dd9b8008 100644
--- a/examples/windows/onnx_ptq/genai_llm/requirements.txt
+++ b/examples/windows/onnx_ptq/genai_llm/requirements.txt
@@ -1,3 +1,4 @@
 datasets>=2.14.5
+onnx==1.18.0
 torch==2.6.0
 transformers==4.49.0
diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index e31b2e48..c2216f8d 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -23,7 +23,6 @@
 import onnx_graphsurgeon as gs
 import torch
 from onnx import numpy_helper
-from onnx.reference.custom_element_types import float8e4m3fn

 from modelopt.onnx import utils
 from modelopt.onnx.logging_config import logger
@@ -50,6 +49,7 @@
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
     "Float": onnx.TensorProto.FLOAT,
+    "Float4": onnx.TensorProto.FLOAT4E2M1,
     "Float8": onnx.TensorProto.FLOAT8E4M3FN,
     "Half": onnx.TensorProto.FLOAT16,
     "INT8": onnx.TensorProto.INT8,
@@ -592,7 +592,7 @@ def _convert_weight(
         zp_array = zp_array.reshape(*reshape_dims)

     # Convert to INT8/FP8
-    if zp_array.dtype == float8e4m3fn:
+    if zp_array.dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
@@ -607,17 +607,26 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
-    array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
+    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
     return array_f8


 def _cast_fp4(array: np.ndarray) -> np.ndarray:
-    """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
+    """Cast a numpy array to FLOAT4E2M1 using PyTorch.
+
+    Note: The first dimension of the array must be divisible by 2
+    as two FP4 values are packed into a single byte.
+    """
     array_f32_t = torch.from_numpy(array)
+    array_f32_t_shape = array_f32_t.shape
+    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
+    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
+    array_f4_t = array_f4_t.flatten()
+    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
     return array_f4


@@ -685,7 +694,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
             scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)

             # Create and update new weight tensor
-            if zp_array.dtype == float8e4m3fn:
+            if zp_array.dtype == onnx_dtype_map["Float8"]:
                 new_weight = _create_fp8_tensor(scaled, weight_name)
                 logger.debug(f"Converted {weight_name} to FP8")
             else:
@@ -920,6 +929,10 @@ def quantize_weights_to_int4(
         assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
         reshape_node_output = reshape_node.output[0]

+        # Remove the shape Constant node feeding the Reshape node
+        shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
+        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+
         # Get the shape of the output of the reshape node
         reshape_output_value_info = value_info_map.get(reshape_node_output)
         if reshape_output_value_info is not None:
@@ -937,12 +950,17 @@ def quantize_weights_to_int4(
             scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
             scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        # reshape_node.input = []
         assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"

+        # Remove unnecessary Cast node
+        cast_node = reshape_child_nodes[0]
+        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+        nodes_to_remove.append(cast_node.name)
+        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+
         # Transpose weights and scales if present
-        if reshape_child_nodes[0].op_type == "Transpose":
-            transpose_node = reshape_child_nodes[0]
+        if cast_child_nodes[0].op_type == "Transpose":
+            transpose_node = cast_child_nodes[0]
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -959,7 +977,7 @@ def quantize_weights_to_int4(
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = reshape_child_nodes[0]
+            matmul_node = cast_child_nodes[0]
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
@@ -990,6 +1008,21 @@ def quantize_weights_to_int4(
             initializer_map[weight_name].CopyFrom(weights_int4_onnx)
             logger.debug(f"Converted {weight_name} to INT4 precision")

+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+        return node.op_type == "Mul" and has_pqs_input
+
+    # Remove unnecessary Cast after Pre-quant scale
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+            cast_node = pqs_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            node.output.clear()
+            node.output.extend(cast_node.output)
+            nodes_to_remove.append(cast_node.name)
+
     # Remove transpose and reshape nodes
     new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
     graph.node.clear()
@@ -1004,7 +1037,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
     for node in graph.node:
         if node.op_type == "Cast":
             # Skip Cast nodes that are part of normalization layers and outputs
-            if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
+            if "norm/Cast" in node.name and is_fp32_cast(node):
                 continue
             for attr in node.attribute:
                 if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
@@ -1099,7 +1132,13 @@ def quantize_weights_to_mxfp8(
         # Expand block array so that it can be broadcasted with weight
         se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
         scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
+        weights_e4m3 = onnx.helper.make_tensor(
+            name=weight_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*scaled_weight.shape],
+            vals=_cast_fp8(scaled_weight).tobytes(),
+            raw=True,
+        )
         initializer_map[weight_name].CopyFrom(weights_e4m3)
         logger.debug(f"Converted {weight_name} to MXFP8")

@@ -1181,11 +1220,24 @@ def _add_input_value_info(graph, tensor_proto):
         sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"

         # Create TensorProto for initializers
-        w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
+        w_f4_proto = onnx.helper.make_tensor(
+            name=w_f4_name,
+            data_type=onnx_dtype_map["Float4"],
+            dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
+            vals=w_f4.tobytes(),
+            raw=True,
+        )
         sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
             sw_f32_per_tensor, sw_f32_per_tensor_name
         )
         sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
+        sw_f8_per_block_proto = onnx.helper.make_tensor(
+            name=sw_f8_per_block_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*sw_f8_per_block.shape],
+            vals=sw_f8_per_block.tobytes(),
+            raw=True,
+        )

         # Add ValueInfo for the initializers if not present
         _add_input_value_info(graph, w_f4_proto)
diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py
index e18a9d20..ac617f72 100644
--- a/modelopt/torch/_deploy/utils/torch_onnx.py
+++ b/modelopt/torch/_deploy/utils/torch_onnx.py
@@ -485,8 +485,8 @@ def get_onnx_bytes_and_metadata(
     except StopIteration:
         param_dtype = torch.float32
     if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
-        if is_mxfp8_quantized(model):
-            assert weights_dtype == "fp16", "BF16 + MXFP8 mixed precision is not supported yet"
+        if is_mxfp8_quantized(model) or is_int4_quantized(model):
+            assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
         onnx_opt_graph = convert_float_to_float16(
             onnx_opt_graph,
             keep_io_types=False,
diff --git a/setup.py b/setup.py
index 45361b1f..46abccac 100644
--- a/setup.py
+++ b/setup.py
@@ -47,8 +47,8 @@
         "cupy-cuda12x; platform_machine != 'aarch64' and platform_system != 'Darwin'",
         "ml_dtypes",  # for bfloat16 conversion
         "onnx-graphsurgeon",
-        "onnx~=1.18.0",
-        "onnxconverter-common",
+        "onnx~=1.19.0",
+        "onnxconverter-common~=1.16.0",
         "onnxruntime~=1.22.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'",
         "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'",  # noqa: E501
         "onnxruntime-directml==1.20.0; platform_system == 'Windows'",
diff --git a/tests/_test_utils/import_helper.py b/tests/_test_utils/import_helper.py
index 03d3c8f2..43f97493 100644
--- a/tests/_test_utils/import_helper.py
+++ b/tests/_test_utils/import_helper.py
@@ -13,9 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import importlib.metadata
 import shutil

 import pytest
+from packaging import version


 def skip_if_no_tensorrt():
@@ -73,3 +75,18 @@ def skip_if_no_megatron(apex_or_te_required: bool = False, mamba_required: bool

     if mamba_required and not has_mamba:
         pytest.skip("Mamba required for Megatron test", allow_module_level=True)
+
+
+def skip_if_onnx_version_above_1_18():
+    package_name = "onnx"
+    required_version = "1.18.0"
+
+    try:
+        installed_version = importlib.metadata.version(package_name)
+    except importlib.metadata.PackageNotFoundError:
+        pytest.skip(f"{package_name} is not installed")
+
+    if version.parse(installed_version) > version.parse(required_version):
+        pytest.skip(
+            f"{package_name} version {installed_version} exceeds the supported maximum {required_version}"
+        )
diff --git a/tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py b/tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py
index 826d9015..e0ff1240 100644
--- a/tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py
+++ b/tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py
@@ -20,7 +20,7 @@
 from functools import partial

 import torch
-from _test_utils.import_helper import skip_if_no_libcudnn
+from _test_utils.import_helper import skip_if_no_libcudnn, skip_if_onnx_version_above_1_18
 from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init
 from _test_utils.torch_quantization.quantize_common import get_awq_config

@@ -40,6 +40,8 @@


 def test_int4_awq(tmp_path):
+    skip_if_onnx_version_above_1_18()
+
     def _forward_loop(model, dataloader):
         """Forward loop for calibration."""
         for data in dataloader:
@@ -114,6 +116,7 @@


 def test_int4_awq_cuda(tmp_path):
+    skip_if_onnx_version_above_1_18()
     skip_if_no_libcudnn()

     block_size = 128
diff --git a/tests/unit/onnx/test_qdq_utils.py b/tests/unit/onnx/test_qdq_utils.py
index 089970d3..ca7d1518 100644
--- a/tests/unit/onnx/test_qdq_utils.py
+++ b/tests/unit/onnx/test_qdq_utils.py
@@ -33,7 +33,6 @@ def create_test_model_with_dq_reshape_transpose_matmul(constant_scale: bool = Fa

     # Create reshape shape tensor
     reshape_shape = np.array([16, 16], dtype=np.int64)
-    reshape_shape_tensor = numpy_helper.from_array(reshape_shape, "reshape_shape")

     # Create input tensor for MatMul
     input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [None, 2])
@@ -53,16 +52,32 @@ def create_test_model_with_dq_reshape_transpose_matmul(constant_scale: bool = Fa
         "DequantizeLinear", inputs=dq_inputs, outputs=["dq_output"], name="weight_dq"
     )

+    reshape_constant = helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=["reshape_shape_Constant"],
+        value=numpy_helper.from_array(reshape_shape),
+        name="reshape_constant",
+    )
+
     reshape_node = helper.make_node(
         "Reshape",
-        inputs=["dq_output", "reshape_shape"],
+        inputs=["dq_output", "reshape_shape_Constant"],
         outputs=["reshape_output"],
         name="weight_reshape",
     )

+    cast_node = helper.make_node(
+        "Cast",
+        inputs=["reshape_output"],
+        outputs=["cast_output"],
+        to=TensorProto.FLOAT,
+        name="weight_cast",
+    )
+
     transpose_node = helper.make_node(
         "Transpose",
-        inputs=["reshape_output"],
+        inputs=["cast_output"],
         outputs=["transpose_output"],
         perm=[1, 0],
         name="weight_transpose",
@@ -78,7 +93,7 @@ def create_test_model_with_dq_reshape_transpose_matmul(constant_scale: bool = Fa
     )

     # Create graph
-    nodes = [dq_node, reshape_node, transpose_node, matmul_node]
+    nodes = [dq_node, reshape_constant, reshape_node, cast_node, transpose_node, matmul_node]
     if constant_scale:
         nodes.append(scale_constant)
     graph = helper.make_graph(
@@ -86,7 +101,7 @@ def create_test_model_with_dq_reshape_transpose_matmul(constant_scale: bool = Fa
         name="test_graph",
         inputs=[input_tensor],
         outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [None, 16])],
-        initializer=[weight_tensor, scale_tensor, reshape_shape_tensor],
+        initializer=[weight_tensor, scale_tensor],
         value_info=[reshape_output_info],
     )

@@ -234,7 +249,7 @@ def test_cast_node_conversion(self):
         for node in graph.node:
             if node.op_type == "Cast":
                 to_attr = next(attr for attr in node.attribute if attr.name == "to")
-                if "norm/Cast" in node.name or node.name == "/Cast":
+                if "norm/Cast" in node.name:
                     # These should remain as float32
                     assert to_attr.i == TensorProto.FLOAT
                 else:
@@ -297,39 +312,39 @@ def test_cast_fp8(self, input_array, expected_array):
        [
            # Basic positive values
            (
-                np.array([0.0, 0.5, 1.0], dtype=np.float32),
-                np.array([0, 1, 2], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[0.0, 0.5], [1.0, 1.5]], dtype=np.float32),
+                np.array([[16, 50]], dtype=np.uint8),
            ),
            # Basic negative values
            (
-                np.array([-0.5, -1.0, -1.5], dtype=np.float32),
-                np.array([9, 10, 11], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[-0.5, -1.0], [-1.5, 1.75]], dtype=np.float32),
+                np.array([[169, 75]], dtype=np.uint8),
            ),
            # Boundary values with rounding
            (
-                np.array([0.75, 1.75, 3.5], dtype=np.float32),
-                np.array([2, 4, 6], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[0.0, 0.75], [1.75, 3.5]], dtype=np.float32),
+                np.array([[32, 100]], dtype=np.uint8),
            ),
            # Large values (saturate to max)
            (
-                np.array([10.0, -10.0], dtype=np.float32),
-                np.array([7, 15], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[10.0], [-10.0]], dtype=np.float32),
+                np.array([[247]], dtype=np.uint8),
            ),
            # Very small values (map to zero)
            (
-                np.array([0.1, -0.1], dtype=np.float32),
-                np.array([0, 8], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[0.1], [-0.1]], dtype=np.float32),
+                np.array([[128]], dtype=np.uint8),
            ),
            # Zero and negative zero
            (
-                np.array([0.0, -0.0], dtype=np.float32),
-                np.array([0, 0], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[0.0], [-0.0]], dtype=np.float32),
+                np.array([[0]], dtype=np.uint8),
            ),
        ],
    )
    def test_cast_fp4(self, input_array, expected_array):
        """Test FP4 casting functionality."""
        result = _cast_fp4(input_array)
-        assert result.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
+        assert result.dtype == np.dtype(np.uint8)
        assert result.shape == expected_array.shape
        assert np.all(result == expected_array)