Skip to content

Commit 7b462dd

Browse files
Added check of the output names while removing cast nodes (#137)
Co-Authored-By: michaelfeil <[email protected]>
1 parent e33610c commit 7b462dd

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,12 +441,16 @@ def _remove_unnecessary_cast():
441441
cast_indices = []
442442

443443
tensor_consumers = get_tensor_consumer_nodes(graph)
444+
output_names = [output.name for output in graph.output]
444445

445446
# find all Cast node with same input and output type
446447
for node_idx, node in enumerate(graph.node):
447448
if node.op_type != "Cast":
448449
continue
449450

451+
if any(out_name in output_names for out_name in node.output):
452+
continue
453+
450454
# if input type matches attribute "to", this is a useless Cast node
451455
assert len(node.input) == 1
452456
input_name = node.input[0]

0 commit comments

Comments
 (0)