14 changes: 14 additions & 0 deletions modelopt/onnx/autocast/precisionconverter.py
@@ -559,6 +559,15 @@ def convert_initializer(

def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
    """
    Bypass (remove) a Cast node by rewiring its producers and consumers in the model graph.

    The node must be a Cast with exactly one input and one output (both asserted below).
    If the Cast's output is a graph output, that name must be preserved: producers that
    originally wrote the Cast's input, and any other consumers of that input, are rewired
    to produce/consume the graph output name instead. Otherwise, consumers of the Cast's
    output are rewired to consume the Cast's input directly. Modifies self.model.graph
    in-place.
    """
    # handling only a single input and output, as we only remove cast nodes
    assert len(node.input) == 1
    assert len(node.output) == 1

@@ -576,6 +585,11 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
            for i, prod_out in enumerate(producer.output):
                if prod_out == input_tensor:
                    producer.output[i] = output_tensor
        # The cast feeds a graph output: also rewire any other consumers of the
        # cast input to read the preserved graph-output name
        consumers = utils.get_consumer_nodes(self.model, input_tensor)
        for consumer in consumers:
            for i, input_name in enumerate(consumer.input):
                if input_name == input_tensor:
                    consumer.input[i] = output_tensor
    # Reconnect consumers of the cast output to use the cast input instead
    if not is_output_producer:
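For context, here is a minimal standalone sketch of the graph-output case the docstring describes. It reimplements the rewiring inline rather than calling `_bypass_cast_node` or the modelopt utilities, and assumes only that `onnx` is installed:

```python
import onnx
from onnx import TensorProto, helper

# Toy graph: Add -> Cast -> graph output "Y".
add = helper.make_node("Add", ["A", "B"], ["add_out"], name="add")
cast = helper.make_node("Cast", ["add_out"], ["Y"], name="cast", to=TensorProto.FLOAT)
graph = helper.make_graph(
    [add, cast],
    "bypass_demo",
    [helper.make_tensor_value_info(n, TensorProto.FLOAT, [2]) for n in ("A", "B")],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2])],
)
model = helper.make_model(graph)

# The Cast's output "Y" is a graph output, so the bypass must preserve that name:
# rewire the producer of "add_out" to write "Y" directly, then drop the Cast.
input_tensor, output_tensor = cast.input[0], cast.output[0]
for node in model.graph.node:
    for i, out in enumerate(node.output):
        if out == input_tensor:
            node.output[i] = output_tensor
model.graph.node.remove(cast)

onnx.checker.check_model(model)
print([n.op_type for n in model.graph.node])  # -> ['Add']
```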
95 changes: 95 additions & 0 deletions tests/unit/onnx/autocast/test_precisionconverter.py
@@ -1023,3 +1023,98 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
    assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
    assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
    assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"


@pytest.fixture
def model_with_casted_output():
"""
Create a tiny ONNX model whose final outputs are produced by Cast nodes.

The graph:
- Input: "X" (float32, [2, 3]).
- A Constant tensor is added twice through Add nodes ("add1", "add2").
- Two Cast nodes ("cast1", "cast2") consume Add outputs and produce the graph outputs "Y1" and "Y2" (cast to FLOAT).
- Model uses opset 20 and has shapes inferred before being returned.

Returns:
tuple: (model, value_info_map, initializer_map, node_to_init_map)
- model (onnx.ModelProto): The checked ONNX model with inferred shapes.
- value_info_map (dict): Mapping from tensor name to ValueInfoProto.
- initializer_map (dict): Mapping from initializer name to TensorProto.
- node_to_init_map (dict): Mapping from node name to its related initializers.
"""
    # Create input and outputs
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Fed by the intermediate Add
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Fed by the final Add

    # Create constant value
    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)

    # Create constant node
    const_node = helper.make_node(
        "Constant",
        [],
        ["const"],
        name="const",
        value=numpy_helper.from_array(const, name="const_value"),
    )

    # Create computation nodes
    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
    add2 = helper.make_node("Add", ["add1_out", "const"], ["add2_out"], name="add2")

    # Create cast nodes producing the graph outputs (identity casts to FLOAT here;
    # they become meaningful once the Add nodes are converted to low precision)
    cast1 = helper.make_node("Cast", ["add1_out"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
    cast2 = helper.make_node("Cast", ["add2_out"], ["Y2"], name="cast2", to=TensorProto.FLOAT)

    graph = helper.make_graph(
        [const_node, add1, add2, cast1, cast2],
        "model_with_casted_output",
        [x],
        [y1, y2],
        [],
    )

    model = helper.make_model(graph, producer_name="model_with_casted_output")
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)

    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)

    return model, value_info_map, initializer_map, node_to_init_map


@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
@pytest.mark.parametrize("keep_io_types", [True, False])
def test_casted_output_model(model_with_casted_output, low_precision_type, keep_io_types):
    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_output

    converter = PrecisionConverter(
        model,
        value_info_map,
        initializer_map,
        node_to_init_map,
        keep_io_types=keep_io_types,
        low_precision_type=low_precision_type,
    )

    converted_model = converter.convert(
        high_precision_nodes=["cast1", "cast2"], low_precision_nodes=["add1", "add2"]
    )
    onnx.checker.check_model(converted_model)

    # Check that the outputs are cast to the correct precision
    if keep_io_types:
        assert converted_model.graph.output[0].type.tensor_type.elem_type == TensorProto.FLOAT
        assert converted_model.graph.output[1].type.tensor_type.elem_type == TensorProto.FLOAT
    else:
        expected_type = low_precision_onnx_type(low_precision_type)
        assert converted_model.graph.output[0].type.tensor_type.elem_type == expected_type
        assert converted_model.graph.output[1].type.tensor_type.elem_type == expected_type
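As a possible follow-up check, here is one way to inspect which node feeds each graph output after conversion. This is a hypothetical helper sketch, not part of the test suite; `output_producers` is an assumed name, and whether the Cast nodes end up bypassed depends on the converter's behavior for the given flags:

```python
def output_producers(model: onnx.ModelProto) -> dict[str, str]:
    """Map each graph-output name to the op_type of the node that writes it."""
    by_output = {out: node.op_type for node in model.graph.node for out in node.output}
    return {o.name: by_output[o.name] for o in model.graph.output}

# If the converter bypassed both Cast nodes, this reports {"Y1": "Add", "Y2": "Add"};
# if it kept them, {"Y1": "Cast", "Y2": "Cast"}.
print(output_producers(converted_model))
```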