examples/onnx_ptq/README.md (7 additions, 0 deletions)

@@ -26,6 +26,13 @@ Model Optimizer enables highly performant quantization formats including NVFP4,

Please use the TensorRT docker image (e.g., `nvcr.io/nvidia/tensorrt:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/TensorRT-Model-Optimizer/getting_started/2_installation.html) for more information.

Set the following environment variables inside the TensorRT docker container:

```bash
export CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/
export LD_LIBRARY_PATH="${CUDNN_LIB_DIR}:${LD_LIBRARY_PATH}"
```
Comment on lines +29 to +34

@kevalmorabia97 (Collaborator) commented on Sep 25, 2025:

This is what we do in the CI/CD: https://nvidia.github.io/TensorRT-Model-Optimizer/getting_started/_installation_for_Linux.html#environment-setup

If that is sufficient, we could also just link that doc here. Otherwise we can update that file too.

@ajrasane (Contributor, Author) replied on Sep 25, 2025:

The steps mentioned above are a bit different for the TRT container:

- There is no /usr/local/tensorrt/ in the TRT container.
- trtexec is stored under /opt/tensorrt/bin, which is already in PATH.
- Adding /usr/lib/x86_64-linux-gnu is sufficient for getting libcudnn.so, which is required for the TRTExecutionProvider.

I would suggest we keep this as it is, since it will be more convenient for users to run these steps as part of the setup right after they start their container. The steps you shared should also work fine for the CI, but they are not required for the TRT container.
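
As a quick sanity check of the container layout described in this thread — a minimal sketch, with paths taken from the comments above rather than verified independently:

```bash
# Sanity-check sketch for the TensorRT container layout described above.
# Paths come from this thread; adjust if your image differs.
which trtexec                        # expected: /opt/tensorrt/bin/trtexec (already in PATH)
ls "${CUDNN_LIB_DIR}"/libcudnn.so*   # libcudnn should resolve from /usr/lib/x86_64-linux-gnu/
```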


Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install example-specific dependencies.

### Local Installation
modelopt/onnx/autocast/convert.py (1 addition, 0 deletions)

@@ -179,6 +179,7 @@ def convert_to_f16(
sanitizer.find_custom_nodes()
sanitizer.convert_opset()
sanitizer.ensure_graph_name_exists()
sanitizer.convert_fp64_to_fp32()
model = sanitizer.model

# Setup internal mappings
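With this change, the new sanitizer step runs inside `convert_to_f16` before the FP16 conversion proper, so FP64 content is normalized to FP32 first. A hypothetical call sketch — the full signature of `convert_to_f16` is not shown in this diff, so treat the argument shape and return value here as assumptions:

```python
# Hypothetical usage sketch: convert_to_f16's keyword arguments are not shown
# in this diff, so everything beyond passing the model is an assumption.
import onnx

from modelopt.onnx.autocast.convert import convert_to_f16

model = onnx.load("model.onnx")     # may contain FP64 initializers/nodes
model_f16 = convert_to_f16(model)   # FP64 is now sanitized to FP32 first
onnx.save(model_f16, "model_f16.onnx")
```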
modelopt/onnx/autocast/graphsanitizer.py (100 additions, 0 deletions)

@@ -65,6 +65,27 @@ def sanitize(self) -> None:
self.replace_custom_domain_nodes()
self.cleanup_model()
self.set_ir_version(self.max_ir_version)
self.convert_fp64_to_fp32()

def convert_fp64_to_fp32(self) -> None:
"""Convert FP64 initializers, I/O types, and specific nodes to FP32."""
modified = False

# Convert initializers
if self._convert_fp64_initializers():
modified = True

# Convert input/output types
if self._convert_fp64_io_types():
modified = True

# Convert specific node types: Cast, ConstantOfShape, Constant
if self._convert_fp64_nodes():
modified = True

if modified:
logger.info("Converted FP64 initializers, I/O types, and nodes to FP32")
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)

def find_custom_nodes(self) -> None:
"""Find custom nodes in the model.
@@ -405,6 +426,85 @@ def _get_initializer_value(self, name: str, return_array: bool = False) -> np.nd
return value if return_array else value.item()
return None

def _convert_fp64_initializers(self) -> bool:
"""Convert FP64 initializers to FP32.

Returns:
bool: True if any initializers were modified, False otherwise.
"""
modified = False

for initializer in self.model.graph.initializer:
if initializer.data_type == onnx.TensorProto.DOUBLE:
# Convert the data to FP32
fp64_data = numpy_helper.to_array(initializer)
fp32_data = fp64_data.astype(np.float32)

# Create new initializer with FP32 data
new_initializer = numpy_helper.from_array(fp32_data, name=initializer.name)

# Replace the old initializer
initializer.CopyFrom(new_initializer)
modified = True
logger.debug(f"Converted initializer {initializer.name} from FP64 to FP32")

return modified

def _convert_fp64_io_types(self) -> bool:
"""Convert FP64 input/output types to FP32.

Returns:
bool: True if any I/O types were modified, False otherwise.
"""
modified = False

def convert_tensor_list(tensors, tensor_type):
nonlocal modified
for tensor in tensors:
if tensor.type.tensor_type.elem_type == onnx.TensorProto.DOUBLE:
tensor.type.tensor_type.elem_type = onnx.TensorProto.FLOAT
modified = True
logger.debug(f"Converted {tensor_type} {tensor.name} from FP64 to FP32")

convert_tensor_list(self.model.graph.input, "input")
convert_tensor_list(self.model.graph.output, "output")
convert_tensor_list(self.model.graph.value_info, "value_info")

return modified

def _convert_fp64_nodes(self) -> bool:
"""Convert specific node types from FP64 to FP32.

Handles Cast, ConstantOfShape, and Constant nodes that use FP64.

Returns:
bool: True if any nodes were modified, False otherwise.
"""
modified = False

for node in self.model.graph.node:
if node.op_type == "Cast":
# Check if casting to FP64, change to FP32
for attr in node.attribute:
if attr.name == "to" and attr.i == onnx.TensorProto.DOUBLE:
attr.i = onnx.TensorProto.FLOAT
modified = True
logger.debug(f"Converted Cast node {node.name} from FP64 to FP32")

elif node.op_type in ["ConstantOfShape", "Constant"]:
# Check if the value attribute uses FP64
for attr in node.attribute:
if attr.name == "value" and attr.t.data_type == onnx.TensorProto.DOUBLE:
# Convert the tensor value to FP32
fp64_data = numpy_helper.to_array(attr.t)
fp32_data = fp64_data.astype(np.float32)
new_tensor = numpy_helper.from_array(fp32_data)
attr.t.CopyFrom(new_tensor)
modified = True
logger.debug(f"Converted {node.op_type} node {node.name} from FP64 to FP32")

return modified

def cleanup_model(self) -> None:
"""Use GraphSurgeon to cleanup unused nodes, tensors and initializers."""
gs_graph = gs.import_onnx(self.model)
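For orientation before the tests, a minimal end-to-end sketch of the new pass, assuming `GraphSanitizer(model)` construction as used in the unit tests below and the module path shown in this diff:

```python
# Minimal sketch of the FP64 -> FP32 sanitization pass. Assumes GraphSanitizer
# is importable from modelopt.onnx.autocast.graphsanitizer (the path in this diff).
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

from modelopt.onnx.autocast.graphsanitizer import GraphSanitizer

# Build a tiny model whose input, initializer, and Cast target are all FP64.
x = helper.make_tensor_value_info("X", TensorProto.DOUBLE, [2, 2])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])
weights = numpy_helper.from_array(np.eye(2, dtype=np.float64), name="weights")
cast = helper.make_node("Cast", ["X"], ["x_cast"], to=TensorProto.DOUBLE)
add = helper.make_node("Add", ["x_cast", "weights"], ["Y"])
graph = helper.make_graph([cast, add], "fp64_demo", [x], [y], initializer=[weights])
model = helper.make_model(graph)

sanitizer = GraphSanitizer(model)
sanitizer.convert_fp64_to_fp32()

# Every FP64 element type has been rewritten to FP32.
assert sanitizer.model.graph.input[0].type.tensor_type.elem_type == TensorProto.FLOAT
assert sanitizer.model.graph.initializer[0].data_type == TensorProto.FLOAT
```

The tests below exercise the same behavior per sub-step (`_convert_fp64_initializers`, `_convert_fp64_io_types`, `_convert_fp64_nodes`) as well as through the top-level `convert_fp64_to_fp32` entry point.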
tests/unit/onnx/autocast/test_graphsanitizer.py (224 additions, 0 deletions)

@@ -183,3 +183,227 @@ def test_invalid_layernorm_pattern():

# Verify no LayerNorm transformation occurred
assert not any(node.op_type == "LayerNormalization" for node in sanitizer.model.graph.node)


def test_convert_fp64_initializers():
"""Test conversion of FP64 initializers to FP32."""
# Create a model with FP64 initializers
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])

# Create FP64 initializers
fp64_weights = np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]], dtype=np.float64)
fp64_bias = np.array([0.1, 0.2, 0.3], dtype=np.float64)
fp32_weights = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

initializers = [
numpy_helper.from_array(fp64_weights, name="fp64_weights"),
numpy_helper.from_array(fp64_bias, name="fp64_bias"),
numpy_helper.from_array(fp32_weights, name="fp32_weights"),
]

# Verify the FP64 initializers have correct data type
assert initializers[0].data_type == TensorProto.DOUBLE
assert initializers[1].data_type == TensorProto.DOUBLE
assert initializers[2].data_type == TensorProto.FLOAT

add_node = helper.make_node("Add", ["X", "fp64_weights"], ["Y"])

graph = helper.make_graph(
nodes=[add_node], name="fp64_test", inputs=[x], outputs=[y], initializer=initializers
)

model = helper.make_model(graph)
sanitizer = GraphSanitizer(model)

# Test the conversion
result = sanitizer._convert_fp64_initializers()
assert result is True

# Verify all initializers are now FP32
for init in sanitizer.model.graph.initializer:
if init.name in ["fp64_weights", "fp64_bias"]:
assert init.data_type == TensorProto.FLOAT
# Verify data integrity
converted_data = numpy_helper.to_array(init)
assert converted_data.dtype == np.float32
elif init.name == "fp32_weights":
assert init.data_type == TensorProto.FLOAT


def test_convert_fp64_io_types():
"""Test conversion of FP64 input/output types to FP32."""
# Create inputs and outputs with FP64 types
x_fp64 = helper.make_tensor_value_info("X_fp64", TensorProto.DOUBLE, [2, 3])
y_fp64 = helper.make_tensor_value_info("Y_fp64", TensorProto.DOUBLE, [2, 3])
x_fp32 = helper.make_tensor_value_info("X_fp32", TensorProto.FLOAT, [2, 3])

# Create value_info with FP64 type
value_info_fp64 = helper.make_tensor_value_info("intermediate", TensorProto.DOUBLE, [2, 3])
value_info_fp32 = helper.make_tensor_value_info("intermediate2", TensorProto.FLOAT, [2, 3])

add_node = helper.make_node("Add", ["X_fp64", "X_fp32"], ["Y_fp64"])

graph = helper.make_graph(
nodes=[add_node],
name="fp64_io_test",
inputs=[x_fp64, x_fp32],
outputs=[y_fp64],
value_info=[value_info_fp64, value_info_fp32],
)

model = helper.make_model(graph)
sanitizer = GraphSanitizer(model)

# Test the conversion
result = sanitizer._convert_fp64_io_types()
assert result is True

# Verify inputs are converted
assert sanitizer.model.graph.input[0].type.tensor_type.elem_type == TensorProto.FLOAT
assert sanitizer.model.graph.input[1].type.tensor_type.elem_type == TensorProto.FLOAT

# Verify outputs are converted
assert sanitizer.model.graph.output[0].type.tensor_type.elem_type == TensorProto.FLOAT

# Verify value_info are converted
for vi in sanitizer.model.graph.value_info:
assert vi.type.tensor_type.elem_type == TensorProto.FLOAT


def test_convert_fp64_nodes():
"""Test conversion of specific node types from FP64 to FP32."""
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])

# Create FP64 constant tensor for ConstantOfShape and Constant nodes
fp64_value = numpy_helper.from_array(np.array([1.5], dtype=np.float64))
fp64_shape_value = numpy_helper.from_array(np.array([2.5], dtype=np.float64))

# Create nodes that use FP64
cast_node = helper.make_node("Cast", ["X"], ["cast_out"], to=TensorProto.DOUBLE)
constant_node = helper.make_node("Constant", [], ["const_out"], value=fp64_value)
constant_shape_node = helper.make_node(
"ConstantOfShape", ["shape"], ["shape_out"], value=fp64_shape_value
)
add_node = helper.make_node("Add", ["cast_out", "const_out"], ["Y"])

# Shape input for ConstantOfShape
shape_init = numpy_helper.from_array(np.array([2, 3], dtype=np.int64), name="shape")

graph = helper.make_graph(
nodes=[cast_node, constant_node, constant_shape_node, add_node],
name="fp64_nodes_test",
inputs=[x],
outputs=[y],
initializer=[shape_init],
)

model = helper.make_model(graph)
sanitizer = GraphSanitizer(model)

# Test the conversion
result = sanitizer._convert_fp64_nodes()
assert result is True

# Verify Cast node is converted
cast_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Cast"]
assert len(cast_nodes) == 1
cast_attr = next(attr for attr in cast_nodes[0].attribute if attr.name == "to")
assert cast_attr.i == TensorProto.FLOAT

# Verify Constant node is converted
constant_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Constant"]
assert len(constant_nodes) == 1
const_attr = next(attr for attr in constant_nodes[0].attribute if attr.name == "value")
assert const_attr.t.data_type == TensorProto.FLOAT

# Verify ConstantOfShape node is converted
const_shape_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "ConstantOfShape"]
assert len(const_shape_nodes) == 1
shape_attr = next(attr for attr in const_shape_nodes[0].attribute if attr.name == "value")
assert shape_attr.t.data_type == TensorProto.FLOAT


def test_convert_fp64_to_fp32_integration():
"""Test the main convert_fp64_to_fp32 method with mixed FP64/FP32 content."""
# Create a model with mixed FP64 and FP32 content
x_fp64 = helper.make_tensor_value_info("X", TensorProto.DOUBLE, [2, 3])
y_fp32 = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])

# FP64 initializer
fp64_weights = numpy_helper.from_array(
np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float64), name="weights"
)

# FP64 constant value
fp64_const_value = numpy_helper.from_array(np.array([0.5], dtype=np.float64))

# Create nodes
cast_node = helper.make_node("Cast", ["X"], ["cast_out"], to=TensorProto.DOUBLE)
constant_node = helper.make_node("Constant", [], ["const_out"], value=fp64_const_value)
add_node = helper.make_node("Add", ["cast_out", "const_out"], ["Y"])

graph = helper.make_graph(
nodes=[cast_node, constant_node, add_node],
name="mixed_fp64_test",
inputs=[x_fp64],
outputs=[y_fp32],
initializer=[fp64_weights],
)

model = helper.make_model(graph)
sanitizer = GraphSanitizer(model)

# Test the main conversion method
sanitizer.convert_fp64_to_fp32()

# Verify all FP64 content has been converted
# Check input types
assert sanitizer.model.graph.input[0].type.tensor_type.elem_type == TensorProto.FLOAT

# Check initializers
for init in sanitizer.model.graph.initializer:
assert init.data_type == TensorProto.FLOAT

# Check Cast node
cast_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Cast"]
cast_attr = next(attr for attr in cast_nodes[0].attribute if attr.name == "to")
assert cast_attr.i == TensorProto.FLOAT

# Check Constant node
constant_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Constant"]
const_attr = next(attr for attr in constant_nodes[0].attribute if attr.name == "value")
assert const_attr.t.data_type == TensorProto.FLOAT


def test_convert_fp64_no_changes_needed():
"""Test that conversion methods return False when no FP64 content exists."""
# Create a model with only FP32 content
x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])

fp32_weights = numpy_helper.from_array(
np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), name="weights"
)
fp32_const_value = numpy_helper.from_array(np.array([0.5], dtype=np.float32))

cast_node = helper.make_node("Cast", ["X"], ["cast_out"], to=TensorProto.FLOAT)
constant_node = helper.make_node("Constant", [], ["const_out"], value=fp32_const_value)
add_node = helper.make_node("Add", ["cast_out", "const_out"], ["Y"])

graph = helper.make_graph(
nodes=[cast_node, constant_node, add_node],
name="fp32_only_test",
inputs=[x],
outputs=[y],
initializer=[fp32_weights],
)

model = helper.make_model(graph)
sanitizer = GraphSanitizer(model)

# Test that no conversions are needed
assert sanitizer._convert_fp64_initializers() is False
assert sanitizer._convert_fp64_io_types() is False
assert sanitizer._convert_fp64_nodes() is False