diff --git a/examples/onnx_ptq/README.md b/examples/onnx_ptq/README.md
index 47058d361..483dde359 100644
--- a/examples/onnx_ptq/README.md
+++ b/examples/onnx_ptq/README.md
@@ -26,6 +26,13 @@ Model Optimizer enables highly performant quantization formats including NVFP4,
 
 Please use the TensorRT docker image (e.g., `nvcr.io/nvidia/tensorrt:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/TensorRT-Model-Optimizer/getting_started/2_installation.html) for more information.
 
+Set the following environment variables inside the TensorRT docker container:
+
+```bash
+export CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/
+export LD_LIBRARY_PATH="${CUDNN_LIB_DIR}:${LD_LIBRARY_PATH}"
+```
+
 Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install example-specific dependencies.
 
 ### Local Installation
diff --git a/modelopt/onnx/autocast/convert.py b/modelopt/onnx/autocast/convert.py
index 5083df34d..a4b5d67fc 100644
--- a/modelopt/onnx/autocast/convert.py
+++ b/modelopt/onnx/autocast/convert.py
@@ -179,6 +179,7 @@ def convert_to_f16(
     sanitizer.find_custom_nodes()
     sanitizer.convert_opset()
     sanitizer.ensure_graph_name_exists()
+    sanitizer.convert_fp64_to_fp32()
     model = sanitizer.model
 
     # Setup internal mappings
diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py
index 83e9df89a..49b78c749 100644
--- a/modelopt/onnx/autocast/graphsanitizer.py
+++ b/modelopt/onnx/autocast/graphsanitizer.py
@@ -65,6 +65,27 @@ def sanitize(self) -> None:
         self.replace_custom_domain_nodes()
         self.cleanup_model()
         self.set_ir_version(self.max_ir_version)
+        self.convert_fp64_to_fp32()
+
+    def convert_fp64_to_fp32(self) -> None:
+        """Convert FP64 initializers, I/O types, and specific nodes to FP32."""
+        modified = False
+
+        # Convert initializers
+        if self._convert_fp64_initializers():
+            modified = True
+
+        # Convert input/output types
+        if self._convert_fp64_io_types():
+            modified = True
+
+        # Convert specific node types: Cast, ConstantOfShape, Constant
+        if self._convert_fp64_nodes():
+            modified = True
+
+        if modified:
+            logger.info("Converted FP64 initializers, I/O types, and nodes to FP32")
+            self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)
 
     def find_custom_nodes(self) -> None:
         """Find custom nodes in the model.
@@ -405,6 +426,85 @@ def _get_initializer_value(self, name: str, return_array: bool = False) -> np.nd
             return value if return_array else value.item()
         return None
 
+    def _convert_fp64_initializers(self) -> bool:
+        """Convert FP64 initializers to FP32.
+
+        Returns:
+            bool: True if any initializers were modified, False otherwise.
+        """
+        modified = False
+
+        for initializer in self.model.graph.initializer:
+            if initializer.data_type == onnx.TensorProto.DOUBLE:
+                # Convert the data to FP32
+                fp64_data = numpy_helper.to_array(initializer)
+                fp32_data = fp64_data.astype(np.float32)
+
+                # Create new initializer with FP32 data
+                new_initializer = numpy_helper.from_array(fp32_data, name=initializer.name)
+
+                # Replace the old initializer
+                initializer.CopyFrom(new_initializer)
+                modified = True
+                logger.debug(f"Converted initializer {initializer.name} from FP64 to FP32")
+
+        return modified
+
+    def _convert_fp64_io_types(self) -> bool:
+        """Convert FP64 input/output types to FP32.
+
+        Returns:
+            bool: True if any I/O types were modified, False otherwise.
+ """ + modified = False + + def convert_tensor_list(tensors, tensor_type): + nonlocal modified + for tensor in tensors: + if tensor.type.tensor_type.elem_type == onnx.TensorProto.DOUBLE: + tensor.type.tensor_type.elem_type = onnx.TensorProto.FLOAT + modified = True + logger.debug(f"Converted {tensor_type} {tensor.name} from FP64 to FP32") + + convert_tensor_list(self.model.graph.input, "input") + convert_tensor_list(self.model.graph.output, "output") + convert_tensor_list(self.model.graph.value_info, "value_info") + + return modified + + def _convert_fp64_nodes(self) -> bool: + """Convert specific node types from FP64 to FP32. + + Handles Cast, ConstantOfShape, and Constant nodes that use FP64. + + Returns: + bool: True if any nodes were modified, False otherwise. + """ + modified = False + + for node in self.model.graph.node: + if node.op_type == "Cast": + # Check if casting to FP64, change to FP32 + for attr in node.attribute: + if attr.name == "to" and attr.i == onnx.TensorProto.DOUBLE: + attr.i = onnx.TensorProto.FLOAT + modified = True + logger.debug(f"Converted Cast node {node.name} from FP64 to FP32") + + elif node.op_type in ["ConstantOfShape", "Constant"]: + # Check if the value attribute uses FP64 + for attr in node.attribute: + if attr.name == "value" and attr.t.data_type == onnx.TensorProto.DOUBLE: + # Convert the tensor value to FP32 + fp64_data = numpy_helper.to_array(attr.t) + fp32_data = fp64_data.astype(np.float32) + new_tensor = numpy_helper.from_array(fp32_data) + attr.t.CopyFrom(new_tensor) + modified = True + logger.debug(f"Converted {node.op_type} node {node.name} from FP64 to FP32") + + return modified + def cleanup_model(self) -> None: """Use GraphSurgeon to cleanup unused nodes, tensors and initializers.""" gs_graph = gs.import_onnx(self.model) diff --git a/tests/unit/onnx/autocast/test_graphsanitizer.py b/tests/unit/onnx/autocast/test_graphsanitizer.py index e1c447da4..cb487b56b 100644 --- a/tests/unit/onnx/autocast/test_graphsanitizer.py +++ b/tests/unit/onnx/autocast/test_graphsanitizer.py @@ -183,3 +183,227 @@ def test_invalid_layernorm_pattern(): # Verify no LayerNorm transformation occurred assert not any(node.op_type == "LayerNormalization" for node in sanitizer.model.graph.node) + + +def test_convert_fp64_initializers(): + """Test conversion of FP64 initializers to FP32.""" + # Create a model with FP64 initializers + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + # Create FP64 initializers + fp64_weights = np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]], dtype=np.float64) + fp64_bias = np.array([0.1, 0.2, 0.3], dtype=np.float64) + fp32_weights = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + + initializers = [ + numpy_helper.from_array(fp64_weights, name="fp64_weights"), + numpy_helper.from_array(fp64_bias, name="fp64_bias"), + numpy_helper.from_array(fp32_weights, name="fp32_weights"), + ] + + # Verify the FP64 initializers have correct data type + assert initializers[0].data_type == TensorProto.DOUBLE + assert initializers[1].data_type == TensorProto.DOUBLE + assert initializers[2].data_type == TensorProto.FLOAT + + add_node = helper.make_node("Add", ["X", "fp64_weights"], ["Y"]) + + graph = helper.make_graph( + nodes=[add_node], name="fp64_test", inputs=[x], outputs=[y], initializer=initializers + ) + + model = helper.make_model(graph) + sanitizer = GraphSanitizer(model) + + # Test the conversion + result = sanitizer._convert_fp64_initializers() + assert 
result is True + + # Verify all initializers are now FP32 + for init in sanitizer.model.graph.initializer: + if init.name in ["fp64_weights", "fp64_bias"]: + assert init.data_type == TensorProto.FLOAT + # Verify data integrity + converted_data = numpy_helper.to_array(init) + assert converted_data.dtype == np.float32 + elif init.name == "fp32_weights": + assert init.data_type == TensorProto.FLOAT + + +def test_convert_fp64_io_types(): + """Test conversion of FP64 input/output types to FP32.""" + # Create inputs and outputs with FP64 types + x_fp64 = helper.make_tensor_value_info("X_fp64", TensorProto.DOUBLE, [2, 3]) + y_fp64 = helper.make_tensor_value_info("Y_fp64", TensorProto.DOUBLE, [2, 3]) + x_fp32 = helper.make_tensor_value_info("X_fp32", TensorProto.FLOAT, [2, 3]) + + # Create value_info with FP64 type + value_info_fp64 = helper.make_tensor_value_info("intermediate", TensorProto.DOUBLE, [2, 3]) + value_info_fp32 = helper.make_tensor_value_info("intermediate2", TensorProto.FLOAT, [2, 3]) + + add_node = helper.make_node("Add", ["X_fp64", "X_fp32"], ["Y_fp64"]) + + graph = helper.make_graph( + nodes=[add_node], + name="fp64_io_test", + inputs=[x_fp64, x_fp32], + outputs=[y_fp64], + value_info=[value_info_fp64, value_info_fp32], + ) + + model = helper.make_model(graph) + sanitizer = GraphSanitizer(model) + + # Test the conversion + result = sanitizer._convert_fp64_io_types() + assert result is True + + # Verify inputs are converted + assert sanitizer.model.graph.input[0].type.tensor_type.elem_type == TensorProto.FLOAT + assert sanitizer.model.graph.input[1].type.tensor_type.elem_type == TensorProto.FLOAT + + # Verify outputs are converted + assert sanitizer.model.graph.output[0].type.tensor_type.elem_type == TensorProto.FLOAT + + # Verify value_info are converted + for vi in sanitizer.model.graph.value_info: + assert vi.type.tensor_type.elem_type == TensorProto.FLOAT + + +def test_convert_fp64_nodes(): + """Test conversion of specific node types from FP64 to FP32.""" + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + # Create FP64 constant tensor for ConstantOfShape and Constant nodes + fp64_value = numpy_helper.from_array(np.array([1.5], dtype=np.float64)) + fp64_shape_value = numpy_helper.from_array(np.array([2.5], dtype=np.float64)) + + # Create nodes that use FP64 + cast_node = helper.make_node("Cast", ["X"], ["cast_out"], to=TensorProto.DOUBLE) + constant_node = helper.make_node("Constant", [], ["const_out"], value=fp64_value) + constant_shape_node = helper.make_node( + "ConstantOfShape", ["shape"], ["shape_out"], value=fp64_shape_value + ) + add_node = helper.make_node("Add", ["cast_out", "const_out"], ["Y"]) + + # Shape input for ConstantOfShape + shape_init = numpy_helper.from_array(np.array([2, 3], dtype=np.int64), name="shape") + + graph = helper.make_graph( + nodes=[cast_node, constant_node, constant_shape_node, add_node], + name="fp64_nodes_test", + inputs=[x], + outputs=[y], + initializer=[shape_init], + ) + + model = helper.make_model(graph) + sanitizer = GraphSanitizer(model) + + # Test the conversion + result = sanitizer._convert_fp64_nodes() + assert result is True + + # Verify Cast node is converted + cast_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Cast"] + assert len(cast_nodes) == 1 + cast_attr = next(attr for attr in cast_nodes[0].attribute if attr.name == "to") + assert cast_attr.i == TensorProto.FLOAT + + # Verify Constant node is converted + constant_nodes 
= [n for n in sanitizer.model.graph.node if n.op_type == "Constant"] + assert len(constant_nodes) == 1 + const_attr = next(attr for attr in constant_nodes[0].attribute if attr.name == "value") + assert const_attr.t.data_type == TensorProto.FLOAT + + # Verify ConstantOfShape node is converted + const_shape_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "ConstantOfShape"] + assert len(const_shape_nodes) == 1 + shape_attr = next(attr for attr in const_shape_nodes[0].attribute if attr.name == "value") + assert shape_attr.t.data_type == TensorProto.FLOAT + + +def test_convert_fp64_to_fp32_integration(): + """Test the main convert_fp64_to_fp32 method with mixed FP64/FP32 content.""" + # Create a model with mixed FP64 and FP32 content + x_fp64 = helper.make_tensor_value_info("X", TensorProto.DOUBLE, [2, 3]) + y_fp32 = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + # FP64 initializer + fp64_weights = numpy_helper.from_array( + np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float64), name="weights" + ) + + # FP64 constant value + fp64_const_value = numpy_helper.from_array(np.array([0.5], dtype=np.float64)) + + # Create nodes + cast_node = helper.make_node("Cast", ["X"], ["cast_out"], to=TensorProto.DOUBLE) + constant_node = helper.make_node("Constant", [], ["const_out"], value=fp64_const_value) + add_node = helper.make_node("Add", ["cast_out", "const_out"], ["Y"]) + + graph = helper.make_graph( + nodes=[cast_node, constant_node, add_node], + name="mixed_fp64_test", + inputs=[x_fp64], + outputs=[y_fp32], + initializer=[fp64_weights], + ) + + model = helper.make_model(graph) + sanitizer = GraphSanitizer(model) + + # Test the main conversion method + sanitizer.convert_fp64_to_fp32() + + # Verify all FP64 content has been converted + # Check input types + assert sanitizer.model.graph.input[0].type.tensor_type.elem_type == TensorProto.FLOAT + + # Check initializers + for init in sanitizer.model.graph.initializer: + assert init.data_type == TensorProto.FLOAT + + # Check Cast node + cast_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Cast"] + cast_attr = next(attr for attr in cast_nodes[0].attribute if attr.name == "to") + assert cast_attr.i == TensorProto.FLOAT + + # Check Constant node + constant_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Constant"] + const_attr = next(attr for attr in constant_nodes[0].attribute if attr.name == "value") + assert const_attr.t.data_type == TensorProto.FLOAT + + +def test_convert_fp64_no_changes_needed(): + """Test that conversion methods return False when no FP64 content exists.""" + # Create a model with only FP32 content + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3]) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + fp32_weights = numpy_helper.from_array( + np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), name="weights" + ) + fp32_const_value = numpy_helper.from_array(np.array([0.5], dtype=np.float32)) + + cast_node = helper.make_node("Cast", ["X"], ["cast_out"], to=TensorProto.FLOAT) + constant_node = helper.make_node("Constant", [], ["const_out"], value=fp32_const_value) + add_node = helper.make_node("Add", ["cast_out", "const_out"], ["Y"]) + + graph = helper.make_graph( + nodes=[cast_node, constant_node, add_node], + name="fp32_only_test", + inputs=[x], + outputs=[y], + initializer=[fp32_weights], + ) + + model = helper.make_model(graph) + sanitizer = GraphSanitizer(model) + + # Test that no conversions are needed + assert 
sanitizer._convert_fp64_initializers() is False + assert sanitizer._convert_fp64_io_types() is False + assert sanitizer._convert_fp64_nodes() is False
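
For reference, a minimal end-to-end sketch of the new pass, mirroring the unit tests above. The `modelopt.onnx.autocast.graphsanitizer` import path is assumed from the repository layout in this diff; everything else uses only APIs exercised by the tests.

```python
import numpy as np
from onnx import TensorProto, helper, numpy_helper

# Import path assumed from the repository layout (modelopt/onnx/autocast/graphsanitizer.py).
from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer

# Build a tiny model that carries FP64 through an initializer, a Cast node,
# and the graph output.
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2])
y = helper.make_tensor_value_info("Y", TensorProto.DOUBLE, [2, 2])
weights = numpy_helper.from_array(np.eye(2, dtype=np.float64), name="weights")
cast_node = helper.make_node("Cast", ["X"], ["x64"], to=TensorProto.DOUBLE)
add_node = helper.make_node("Add", ["x64", "weights"], ["Y"])
graph = helper.make_graph([cast_node, add_node], "fp64_demo", [x], [y], initializer=[weights])
model = helper.make_model(graph)

# Run only the new FP64-to-FP32 pass; sanitize() would also run it as its last step,
# and convert_to_f16() now invokes it before building its internal mappings.
sanitizer = GraphSanitizer(model)
sanitizer.convert_fp64_to_fp32()

# The initializer, the Cast target type, and the graph output are now FP32.
assert sanitizer.model.graph.initializer[0].data_type == TensorProto.FLOAT
assert sanitizer.model.graph.output[0].type.tensor_type.elem_type == TensorProto.FLOAT
```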