We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e33610c commit 7b462ddCopy full SHA for 7b462dd
modelopt/onnx/quantization/qdq_utils.py
@@ -441,12 +441,16 @@ def _remove_unnecessary_cast():
441
cast_indices = []
442
443
tensor_consumers = get_tensor_consumer_nodes(graph)
444
+ output_names = [output.name for output in graph.output]
445
446
# find all Cast node with same input and output type
447
for node_idx, node in enumerate(graph.node):
448
if node.op_type != "Cast":
449
continue
450
451
+ if any(out_name in output_names for out_name in node.output):
452
+ continue
453
+
454
# if input type matches attribute "to", this is a useless Cast node
455
assert len(node.input) == 1
456
input_name = node.input[0]
0 commit comments