
Commit a041bbe

[5532019][AutoCast] Update GraphSanitizer to convert Double to Float32 (NVIDIA#364)
Signed-off-by: ajrasane <[email protected]>
1 parent b4d6ced commit a041bbe

File tree

4 files changed: +332 −0 lines changed

examples/onnx_ptq/README.md

Lines changed: 7 additions & 0 deletions

@@ -26,6 +26,13 @@ Model Optimizer enables highly performant quantization formats including NVFP4,
 
 Please use the TensorRT docker image (e.g., `nvcr.io/nvidia/tensorrt:25.08-py3`) or visit our [installation docs](https://nvidia.github.io/TensorRT-Model-Optimizer/getting_started/2_installation.html) for more information.
 
+Set the following environment variables inside the TensorRT docker.
+
+```bash
+export CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/
+export LD_LIBRARY_PATH="${CUDNN_LIB_DIR}:${LD_LIBRARY_PATH}"
+```
+
 Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install example-specific dependencies.
 
 ### Local Installation

modelopt/onnx/autocast/convert.py

Lines changed: 1 addition & 0 deletions

@@ -179,6 +179,7 @@ def convert_to_f16(
     sanitizer.find_custom_nodes()
     sanitizer.convert_opset()
     sanitizer.ensure_graph_name_exists()
+    sanitizer.convert_fp64_to_fp32()
     model = sanitizer.model
 
     # Setup internal mappings
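
For context, a minimal sketch of how this pipeline entry point might be invoked. The exact `convert_to_f16` signature is not shown in this diff; the sketch assumes it accepts a loaded `onnx.ModelProto` as its first argument and returns the converted model, and the file paths are hypothetical.

```python
# Hedged sketch: assumes convert_to_f16(model) -> onnx.ModelProto; the input
# path "model_with_fp64.onnx" is a hypothetical example.
import onnx

from modelopt.onnx.autocast.convert import convert_to_f16

model = onnx.load("model_with_fp64.onnx")
# With this commit, FP64 initializers, I/O types, and Cast/Constant nodes
# are sanitized to FP32 before the FP16 conversion runs.
model_f16 = convert_to_f16(model)
onnx.save(model_f16, "model_f16.onnx")
```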

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 100 additions & 0 deletions

@@ -65,6 +65,27 @@ def sanitize(self) -> None:
         self.replace_custom_domain_nodes()
         self.cleanup_model()
         self.set_ir_version(self.max_ir_version)
+        self.convert_fp64_to_fp32()
+
+    def convert_fp64_to_fp32(self) -> None:
+        """Convert FP64 initializers, I/O types, and specific nodes to FP32."""
+        modified = False
+
+        # Convert initializers
+        if self._convert_fp64_initializers():
+            modified = True
+
+        # Convert input/output types
+        if self._convert_fp64_io_types():
+            modified = True
+
+        # Convert specific node types: Cast, ConstantOfShape, Constant
+        if self._convert_fp64_nodes():
+            modified = True
+
+        if modified:
+            logger.info("Converted FP64 initializers, I/O types, and nodes to FP32")
+            self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)
 
     def find_custom_nodes(self) -> None:
         """Find custom nodes in the model.

@@ -405,6 +426,85 @@ def _get_initializer_value(self, name: str, return_array: bool = False) -> np.nd
             return value if return_array else value.item()
         return None
 
+    def _convert_fp64_initializers(self) -> bool:
+        """Convert FP64 initializers to FP32.
+
+        Returns:
+            bool: True if any initializers were modified, False otherwise.
+        """
+        modified = False
+
+        for initializer in self.model.graph.initializer:
+            if initializer.data_type == onnx.TensorProto.DOUBLE:
+                # Convert the data to FP32
+                fp64_data = numpy_helper.to_array(initializer)
+                fp32_data = fp64_data.astype(np.float32)
+
+                # Create new initializer with FP32 data
+                new_initializer = numpy_helper.from_array(fp32_data, name=initializer.name)
+
+                # Replace the old initializer
+                initializer.CopyFrom(new_initializer)
+                modified = True
+                logger.debug(f"Converted initializer {initializer.name} from FP64 to FP32")
+
+        return modified
+
+    def _convert_fp64_io_types(self) -> bool:
+        """Convert FP64 input/output types to FP32.
+
+        Returns:
+            bool: True if any I/O types were modified, False otherwise.
+        """
+        modified = False
+
+        def convert_tensor_list(tensors, tensor_type):
+            nonlocal modified
+            for tensor in tensors:
+                if tensor.type.tensor_type.elem_type == onnx.TensorProto.DOUBLE:
+                    tensor.type.tensor_type.elem_type = onnx.TensorProto.FLOAT
+                    modified = True
+                    logger.debug(f"Converted {tensor_type} {tensor.name} from FP64 to FP32")
+
+        convert_tensor_list(self.model.graph.input, "input")
+        convert_tensor_list(self.model.graph.output, "output")
+        convert_tensor_list(self.model.graph.value_info, "value_info")
+
+        return modified
+
+    def _convert_fp64_nodes(self) -> bool:
+        """Convert specific node types from FP64 to FP32.
+
+        Handles Cast, ConstantOfShape, and Constant nodes that use FP64.
+
+        Returns:
+            bool: True if any nodes were modified, False otherwise.
+        """
+        modified = False
+
+        for node in self.model.graph.node:
+            if node.op_type == "Cast":
+                # Check if casting to FP64, change to FP32
+                for attr in node.attribute:
+                    if attr.name == "to" and attr.i == onnx.TensorProto.DOUBLE:
+                        attr.i = onnx.TensorProto.FLOAT
+                        modified = True
+                        logger.debug(f"Converted Cast node {node.name} from FP64 to FP32")
+
+            elif node.op_type in ["ConstantOfShape", "Constant"]:
+                # Check if the value attribute uses FP64
+                for attr in node.attribute:
+                    if attr.name == "value" and attr.t.data_type == onnx.TensorProto.DOUBLE:
+                        # Convert the tensor value to FP32
+                        fp64_data = numpy_helper.to_array(attr.t)
+                        fp32_data = fp64_data.astype(np.float32)
+                        new_tensor = numpy_helper.from_array(fp32_data)
+                        attr.t.CopyFrom(new_tensor)
+                        modified = True
+                        logger.debug(f"Converted {node.op_type} node {node.name} from FP64 to FP32")
+
+        return modified
+
     def cleanup_model(self) -> None:
         """Use GraphSurgeon to cleanup unused nodes, tensors and initializers."""
         gs_graph = gs.import_onnx(self.model)
tests/unit/onnx/autocast/test_graphsanitizer.py

Lines changed: 224 additions & 0 deletions

@@ -183,3 +183,227 @@ def test_invalid_layernorm_pattern():
 
     # Verify no LayerNorm transformation occurred
     assert not any(node.op_type == "LayerNormalization" for node in sanitizer.model.graph.node)
+
+
+def test_convert_fp64_initializers():
+    """Test conversion of FP64 initializers to FP32."""
+    # Create a model with FP64 initializers
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
+    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
+
+    # Create FP64 initializers
+    fp64_weights = np.array([[1.5, 2.5, 3.5], [4.5, 5.5, 6.5]], dtype=np.float64)
+    fp64_bias = np.array([0.1, 0.2, 0.3], dtype=np.float64)
+    fp32_weights = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
+
+    initializers = [
+        numpy_helper.from_array(fp64_weights, name="fp64_weights"),
+        numpy_helper.from_array(fp64_bias, name="fp64_bias"),
+        numpy_helper.from_array(fp32_weights, name="fp32_weights"),
+    ]
+
+    # Verify the FP64 initializers have correct data type
+    assert initializers[0].data_type == TensorProto.DOUBLE
+    assert initializers[1].data_type == TensorProto.DOUBLE
+    assert initializers[2].data_type == TensorProto.FLOAT
+
+    add_node = helper.make_node("Add", ["X", "fp64_weights"], ["Y"])
+
+    graph = helper.make_graph(
+        nodes=[add_node], name="fp64_test", inputs=[x], outputs=[y], initializer=initializers
+    )
+
+    model = helper.make_model(graph)
+    sanitizer = GraphSanitizer(model)
+
+    # Test the conversion
+    result = sanitizer._convert_fp64_initializers()
+    assert result is True
+
+    # Verify all initializers are now FP32
+    for init in sanitizer.model.graph.initializer:
+        if init.name in ["fp64_weights", "fp64_bias"]:
+            assert init.data_type == TensorProto.FLOAT
+            # Verify data integrity
+            converted_data = numpy_helper.to_array(init)
+            assert converted_data.dtype == np.float32
+        elif init.name == "fp32_weights":
+            assert init.data_type == TensorProto.FLOAT
+
+
+def test_convert_fp64_io_types():
+    """Test conversion of FP64 input/output types to FP32."""
+    # Create inputs and outputs with FP64 types
+    x_fp64 = helper.make_tensor_value_info("X_fp64", TensorProto.DOUBLE, [2, 3])
+    y_fp64 = helper.make_tensor_value_info("Y_fp64", TensorProto.DOUBLE, [2, 3])
+    x_fp32 = helper.make_tensor_value_info("X_fp32", TensorProto.FLOAT, [2, 3])
+
+    # Create value_info with FP64 type
+    value_info_fp64 = helper.make_tensor_value_info("intermediate", TensorProto.DOUBLE, [2, 3])
+    value_info_fp32 = helper.make_tensor_value_info("intermediate2", TensorProto.FLOAT, [2, 3])
+
+    add_node = helper.make_node("Add", ["X_fp64", "X_fp32"], ["Y_fp64"])
+
+    graph = helper.make_graph(
+        nodes=[add_node],
+        name="fp64_io_test",
+        inputs=[x_fp64, x_fp32],
+        outputs=[y_fp64],
+        value_info=[value_info_fp64, value_info_fp32],
+    )
+
+    model = helper.make_model(graph)
+    sanitizer = GraphSanitizer(model)
+
+    # Test the conversion
+    result = sanitizer._convert_fp64_io_types()
+    assert result is True
+
+    # Verify inputs are converted
+    assert sanitizer.model.graph.input[0].type.tensor_type.elem_type == TensorProto.FLOAT
+    assert sanitizer.model.graph.input[1].type.tensor_type.elem_type == TensorProto.FLOAT
+
+    # Verify outputs are converted
+    assert sanitizer.model.graph.output[0].type.tensor_type.elem_type == TensorProto.FLOAT
+
+    # Verify value_info are converted
+    for vi in sanitizer.model.graph.value_info:
+        assert vi.type.tensor_type.elem_type == TensorProto.FLOAT
+
+
+def test_convert_fp64_nodes():
+    """Test conversion of specific node types from FP64 to FP32."""
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
+    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
+
+    # Create FP64 constant tensor for ConstantOfShape and Constant nodes
+    fp64_value = numpy_helper.from_array(np.array([1.5], dtype=np.float64))
+    fp64_shape_value = numpy_helper.from_array(np.array([2.5], dtype=np.float64))
+
+    # Create nodes that use FP64
+    cast_node = helper.make_node("Cast", ["X"], ["cast_out"], to=TensorProto.DOUBLE)
+    constant_node = helper.make_node("Constant", [], ["const_out"], value=fp64_value)
+    constant_shape_node = helper.make_node(
+        "ConstantOfShape", ["shape"], ["shape_out"], value=fp64_shape_value
+    )
+    add_node = helper.make_node("Add", ["cast_out", "const_out"], ["Y"])
+
+    # Shape input for ConstantOfShape
+    shape_init = numpy_helper.from_array(np.array([2, 3], dtype=np.int64), name="shape")
+
+    graph = helper.make_graph(
+        nodes=[cast_node, constant_node, constant_shape_node, add_node],
+        name="fp64_nodes_test",
+        inputs=[x],
+        outputs=[y],
+        initializer=[shape_init],
+    )
+
+    model = helper.make_model(graph)
+    sanitizer = GraphSanitizer(model)
+
+    # Test the conversion
+    result = sanitizer._convert_fp64_nodes()
+    assert result is True
+
+    # Verify Cast node is converted
+    cast_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Cast"]
+    assert len(cast_nodes) == 1
+    cast_attr = next(attr for attr in cast_nodes[0].attribute if attr.name == "to")
+    assert cast_attr.i == TensorProto.FLOAT
+
+    # Verify Constant node is converted
+    constant_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Constant"]
+    assert len(constant_nodes) == 1
+    const_attr = next(attr for attr in constant_nodes[0].attribute if attr.name == "value")
+    assert const_attr.t.data_type == TensorProto.FLOAT
+
+    # Verify ConstantOfShape node is converted
+    const_shape_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "ConstantOfShape"]
+    assert len(const_shape_nodes) == 1
+    shape_attr = next(attr for attr in const_shape_nodes[0].attribute if attr.name == "value")
+    assert shape_attr.t.data_type == TensorProto.FLOAT
+
+
+def test_convert_fp64_to_fp32_integration():
+    """Test the main convert_fp64_to_fp32 method with mixed FP64/FP32 content."""
+    # Create a model with mixed FP64 and FP32 content
+    x_fp64 = helper.make_tensor_value_info("X", TensorProto.DOUBLE, [2, 3])
+    y_fp32 = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
+
+    # FP64 initializer
+    fp64_weights = numpy_helper.from_array(
+        np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float64), name="weights"
+    )
+
+    # FP64 constant value
+    fp64_const_value = numpy_helper.from_array(np.array([0.5], dtype=np.float64))
+
+    # Create nodes
+    cast_node = helper.make_node("Cast", ["X"], ["cast_out"], to=TensorProto.DOUBLE)
+    constant_node = helper.make_node("Constant", [], ["const_out"], value=fp64_const_value)
+    add_node = helper.make_node("Add", ["cast_out", "const_out"], ["Y"])
+
+    graph = helper.make_graph(
+        nodes=[cast_node, constant_node, add_node],
+        name="mixed_fp64_test",
+        inputs=[x_fp64],
+        outputs=[y_fp32],
+        initializer=[fp64_weights],
+    )
+
+    model = helper.make_model(graph)
+    sanitizer = GraphSanitizer(model)
+
+    # Test the main conversion method
+    sanitizer.convert_fp64_to_fp32()
+
+    # Verify all FP64 content has been converted
+    # Check input types
+    assert sanitizer.model.graph.input[0].type.tensor_type.elem_type == TensorProto.FLOAT
+
+    # Check initializers
+    for init in sanitizer.model.graph.initializer:
+        assert init.data_type == TensorProto.FLOAT
+
+    # Check Cast node
+    cast_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Cast"]
+    cast_attr = next(attr for attr in cast_nodes[0].attribute if attr.name == "to")
+    assert cast_attr.i == TensorProto.FLOAT
+
+    # Check Constant node
+    constant_nodes = [n for n in sanitizer.model.graph.node if n.op_type == "Constant"]
+    const_attr = next(attr for attr in constant_nodes[0].attribute if attr.name == "value")
+    assert const_attr.t.data_type == TensorProto.FLOAT
+
+
+def test_convert_fp64_no_changes_needed():
+    """Test that conversion methods return False when no FP64 content exists."""
+    # Create a model with only FP32 content
+    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
+    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
+
+    fp32_weights = numpy_helper.from_array(
+        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), name="weights"
+    )
+    fp32_const_value = numpy_helper.from_array(np.array([0.5], dtype=np.float32))
+
+    cast_node = helper.make_node("Cast", ["X"], ["cast_out"], to=TensorProto.FLOAT)
+    constant_node = helper.make_node("Constant", [], ["const_out"], value=fp32_const_value)
+    add_node = helper.make_node("Add", ["cast_out", "const_out"], ["Y"])
+
+    graph = helper.make_graph(
+        nodes=[cast_node, constant_node, add_node],
+        name="fp32_only_test",
+        inputs=[x],
+        outputs=[y],
+        initializer=[fp32_weights],
+    )
+
+    model = helper.make_model(graph)
+    sanitizer = GraphSanitizer(model)
+
+    # Test that no conversions are needed
+    assert sanitizer._convert_fp64_initializers() is False
+    assert sanitizer._convert_fp64_io_types() is False
+    assert sanitizer._convert_fp64_nodes() is False