Skip to content

Commit b921811

Browse files
committed
Added unittest
Signed-off-by: gcunhase <[email protected]>
1 parent 48d8344 commit b921811

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,3 +1023,81 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
10231023
assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
10241024
assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
10251025
assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"
1026+
1027+
1028+
@pytest.fixture
1029+
def model_with_multiple_output_node_casted_to_output():
1030+
"""Create a model with a Cast node connecting a consumer with multiple outputs to a graph output."""
1031+
# Create inputs and outputs
1032+
x1 = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [1, 2, 16, 16])
1033+
x2 = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [1, 3, 16, 16])
1034+
x3 = helper.make_tensor_value_info("X3", TensorProto.FLOAT, [1, 4, 16, 16])
1035+
y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [1, 5, 16, 16])
1036+
y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [1, 9, 16, 16])
1037+
1038+
# Create computation nodes
1039+
concat_1_node = helper.make_node(
1040+
"Concat",
1041+
["X1", "X2"],
1042+
["concat_1_out"],
1043+
name="concat_1",
1044+
axis=1,
1045+
)
1046+
concat_2_node = helper.make_node(
1047+
"Concat",
1048+
["concat_1_out", "X3"],
1049+
["Y2"],
1050+
name="concat_2",
1051+
axis=1,
1052+
)
1053+
1054+
# Create a Cast node between 'concat_1' and the graph output
1055+
cast_node = helper.make_node(
1056+
"Cast",
1057+
["concat_1_out"],
1058+
["Y1"],
1059+
name="cast_0",
1060+
to=TensorProto.FLOAT,
1061+
)
1062+
1063+
graph = helper.make_graph(
1064+
[concat_1_node, concat_2_node, cast_node],
1065+
"model_with_multiple_output_node_casted_to_output",
1066+
[x1, x2, x3],
1067+
[y1, y2],
1068+
[],
1069+
)
1070+
1071+
model = helper.make_model(
1072+
graph, producer_name="model_with_multiple_output_node_casted_to_output"
1073+
)
1074+
model.opset_import[0].version = 20
1075+
model.ir_version = 10
1076+
onnx.checker.check_model(model)
1077+
1078+
model = onnx_utils.infer_shapes(model)
1079+
value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
1080+
1081+
return model, value_info_map, initializer_map, node_to_init_map
1082+
1083+
1084+
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
1085+
def test_multiple_output_node_casted_to_output(
1086+
model_with_multiple_output_node_casted_to_output, low_precision_type
1087+
):
1088+
model, value_info_map, initializer_map, node_to_init_map = (
1089+
model_with_multiple_output_node_casted_to_output
1090+
)
1091+
1092+
converter = PrecisionConverter(
1093+
model,
1094+
value_info_map,
1095+
initializer_map,
1096+
node_to_init_map,
1097+
keep_io_types=True,
1098+
low_precision_type=low_precision_type,
1099+
)
1100+
converted_model = converter.convert(
1101+
high_precision_nodes=[], low_precision_nodes=["concat_1", "concat_2"]
1102+
)
1103+
onnx.checker.check_model(converted_model)

0 commit comments

Comments
 (0)