Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions modelopt/onnx/autocast/precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,11 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
for i, prod_out in enumerate(producer.output):
if prod_out == input_tensor:
producer.output[i] = output_tensor
consumers = utils.get_consumer_nodes(self.model, input_tensor)
for consumer in consumers:
for i, input_name in enumerate(consumer.input):
if input_name == input_tensor:
consumer.input[i] = output_tensor
if (
not is_output_producer
): # Reconnect consumers of the cast output to use the cast input instead
Expand Down
80 changes: 80 additions & 0 deletions tests/unit/onnx/autocast/test_precisionconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,3 +1023,83 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"


@pytest.fixture
def model_with_casted_output():
    """Create a model whose graph outputs are each produced by a Cast node.

    Graph topology::

        X ----+
              Add(add1) -> add1_out -> Cast(cast1) -> Y1
        const-+               |
                              +-- Add(add2, with const) -> add2_out -> Cast(cast2) -> Y2

    ``add1_out`` feeds both ``cast1`` (which produces graph output ``Y1``) and
    ``add2``, so the first Cast sits on a tensor that also has another consumer.

    Returns:
        A tuple ``(model, value_info_map, initializer_map, node_to_init_map)``
        where the three maps are the result of ``utils.setup_mappings(model)``.
    """
    # Graph input and the two graph outputs, all declared as FLOAT [2, 3].
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Output taken after add1
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Output taken after add2

    # Constant operand shared by both Add nodes.
    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
    const_node = helper.make_node(
        "Constant",
        [],
        ["const"],
        name="const",
        value=numpy_helper.from_array(const, name="const_value"),
    )

    # Computation nodes.
    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
    add2 = helper.make_node("Add", ["add1_out", "const"], ["add2_out"], name="add2")

    # Cast nodes to higher precision (FLOAT32); each one directly produces a graph output.
    cast1 = helper.make_node("Cast", ["add1_out"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
    cast2 = helper.make_node("Cast", ["add2_out"], ["Y2"], name="cast2", to=TensorProto.FLOAT)

    graph = helper.make_graph(
        [const_node, add1, add2, cast1, cast2],
        "model_with_casted_output",
        [x],
        [y1, y2],
        [],
    )

    model = helper.make_model(graph, producer_name="model_with_casted_output")
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)

    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    # NOTE: the model is intentionally not written to disk here. A hard-coded
    # "/tmp/..." path is not portable and is shared between parallel test runs.

    return model, value_info_map, initializer_map, node_to_init_map


@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
@pytest.mark.parametrize("keep_io_types", [True, False])
def test_casted_output_model(model_with_casted_output, low_precision_type, keep_io_types):
    """Convert a model whose outputs come from Cast nodes and check the output types.

    With ``keep_io_types`` set, both graph outputs must stay FLOAT; otherwise
    both must have the ONNX type matching ``low_precision_type``.
    """
    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_output

    precision_converter = PrecisionConverter(
        model,
        value_info_map,
        initializer_map,
        node_to_init_map,
        keep_io_types=keep_io_types,
        low_precision_type=low_precision_type,
    )

    # Casts are kept in high precision while the Add nodes are lowered.
    high_precision = ["cast1", "cast2"]
    low_precision = ["add1", "add2"]
    converted_model = precision_converter.convert(
        high_precision_nodes=high_precision, low_precision_nodes=low_precision
    )
    onnx.checker.check_model(converted_model)

    # Check that the output is casted to the correct precision
    if keep_io_types:
        expected_elem_type = TensorProto.FLOAT
    else:
        expected_elem_type = low_precision_onnx_type(low_precision_type)

    graph_outputs = converted_model.graph.output
    for output_index in (0, 1):
        assert graph_outputs[output_index].type.tensor_type.elem_type == expected_elem_type