Skip to content

Commit ae6adeb

Browse files
📝 Add docstrings to aboubezari/fix_autocast_cast_output_producer_bug
Docstrings generation was requested by @aboubezari. * #302 (comment) The following files were modified: * `modelopt/onnx/autocast/precisionconverter.py` * `tests/unit/onnx/autocast/test_precisionconverter.py`
1 parent cf6f1d4 commit ae6adeb

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,15 @@ def convert_initializer(
559559

560560
def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
561561
# handling only a single input and output, as we only remove cast nodes
562+
"""
563+
Bypass (remove) a Cast node by rewiring its producer(s) and consumer(s) in-place in the model graph.
564+
565+
This function expects the provided node to be a Cast with exactly one input and one output (asserted).
566+
If the Cast's output is a graph output, the graph output name must be preserved: producers that
567+
originally wrote the Cast input and any consumers of that input are rewired to produce/use the
568+
graph output name instead. Otherwise, consumers of the Cast output are rewired to consume the
569+
Cast input directly. Modifies self.model.graph in-place.
570+
"""
562571
assert len(node.input) == 1
563572
assert len(node.output) == 1
564573

@@ -576,6 +585,11 @@ def _bypass_cast_node(self, node: onnx.NodeProto) -> None:
576585
for i, prod_out in enumerate(producer.output):
577586
if prod_out == input_tensor:
578587
producer.output[i] = output_tensor
588+
consumers = utils.get_consumer_nodes(self.model, input_tensor)
589+
for consumer in consumers:
590+
for i, input_name in enumerate(consumer.input):
591+
if input_name == input_tensor:
592+
consumer.input[i] = output_tensor
579593
if (
580594
not is_output_producer
581595
): # Reconnect consumers of the cast output to use the cast input instead

tests/unit/onnx/autocast/test_precisionconverter.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,3 +1023,98 @@ def test_constant_cast_folding(model_with_constant_cast_patterns, low_precision_
10231023
assert utils.get_consumer_nodes(converted_model, "const_scalar")[0].op_type == "Add"
10241024
assert len(utils.get_consumer_nodes(converted_model, "const_array")) == 1
10251025
assert utils.get_consumer_nodes(converted_model, "const_array")[0].op_type == "Add"
1026+
1027+
1028+
@pytest.fixture
def model_with_casted_output():
    """
    Create a tiny ONNX model whose final outputs are produced by Cast nodes.

    The graph:
    - Input: "X" (float32, [2, 3]).
    - A Constant tensor is added twice through Add nodes ("add1", "add2").
    - Two Cast nodes ("cast1", "cast2") consume the Add outputs and produce the
      graph outputs "Y1" and "Y2" (cast to FLOAT).
    - Model uses opset 20 / IR version 10 and has shapes inferred before being returned.

    Returns:
        tuple: (model, value_info_map, initializer_map, node_to_init_map)
            - model (onnx.ModelProto): The checked ONNX model with inferred shapes.
            - value_info_map (dict): Mapping from tensor name to ValueInfoProto.
            - initializer_map (dict): Mapping from initializer name to TensorProto.
            - node_to_init_map (dict): Mapping from node name to its related initializers.
    """
    # Create input and outputs
    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3])
    y1 = helper.make_tensor_value_info("Y1", TensorProto.FLOAT, [2, 3])  # Intermediate output
    y2 = helper.make_tensor_value_info("Y2", TensorProto.FLOAT, [2, 3])  # Final output

    # Constant value shared by both Add nodes
    const = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)

    const_node = helper.make_node(
        "Constant",
        [],
        ["const"],
        name="const",
        value=numpy_helper.from_array(const, name="const_value"),
    )

    # Computation nodes
    add1 = helper.make_node("Add", ["X", "const"], ["add1_out"], name="add1")
    add2 = helper.make_node("Add", ["add1_out", "const"], ["add2_out"], name="add2")

    # Cast nodes producing the graph outputs (cast to FLOAT)
    cast1 = helper.make_node("Cast", ["add1_out"], ["Y1"], name="cast1", to=TensorProto.FLOAT)
    cast2 = helper.make_node("Cast", ["add2_out"], ["Y2"], name="cast2", to=TensorProto.FLOAT)

    graph = helper.make_graph(
        [const_node, add1, add2, cast1, cast2],
        "model_with_casted_output",
        [x],
        [y1, y2],
        [],
    )

    model = helper.make_model(graph, producer_name="model_with_casted_output")
    model.opset_import[0].version = 20
    model.ir_version = 10
    onnx.checker.check_model(model)

    model = onnx_utils.infer_shapes(model)
    value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
    # NOTE(review): removed leftover debug side effect
    # `onnx.save(model, "/tmp/model_with_casted_output.onnx")` — the file was
    # never read by any test and the hard-coded /tmp path is not portable.

    return model, value_info_map, initializer_map, node_to_init_map
1089+
1090+
1091+
@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"])
@pytest.mark.parametrize("keep_io_types", [True, False])
def test_casted_output_model(model_with_casted_output, low_precision_type, keep_io_types):
    """Check graph-output dtypes after conversion when the outputs are produced by Cast nodes."""
    model, value_info_map, initializer_map, node_to_init_map = model_with_casted_output

    converter = PrecisionConverter(
        model,
        value_info_map,
        initializer_map,
        node_to_init_map,
        keep_io_types=keep_io_types,
        low_precision_type=low_precision_type,
    )

    converted = converter.convert(
        high_precision_nodes=["cast1", "cast2"], low_precision_nodes=["add1", "add2"]
    )
    onnx.checker.check_model(converted)

    # Both graph outputs must carry the expected precision: FLOAT when I/O types
    # are preserved, otherwise the requested low-precision ONNX type.
    if keep_io_types:
        expected_elem_type = TensorProto.FLOAT
    else:
        expected_elem_type = low_precision_onnx_type(low_precision_type)

    for output_index in (0, 1):
        graph_output = converted.graph.output[output_index]
        assert graph_output.type.tensor_type.elem_type == expected_elem_type

0 commit comments

Comments
 (0)