@@ -332,20 +332,26 @@ def sanitize_io_casts(self) -> None:
332
332
model_output_names = {output .name for output in self .model .graph .output }
333
333
nodes_to_add = []
334
334
for node in self .model .graph .node :
335
- if node .op_type == "Cast" :
336
- if node .input [0 ] in model_input_names and node .output [0 ] in model_output_names :
337
- cast_input_name = node .input [0 ]
338
- cast_output_name = node .output [0 ]
339
- cast_new_output_name = cast_input_name + "_io_cast_identity"
340
- nodes_to_add .append (
341
- helper .make_node (
342
- "Identity" ,
343
- inputs = [cast_new_output_name ],
344
- outputs = [cast_output_name ],
345
- name = node .name + "_io_cast_identity" ,
346
- )
335
+ if (
336
+ node .op_type == "Cast"
337
+ and node .input
338
+ and node .output
339
+ and node .input [0 ] in model_input_names
340
+ and node .output [0 ] in model_output_names
341
+ ):
342
+ # Unique per graph output to avoid collisions when multiple outputs are cast from the same input
343
+ cast_output_name = node .output [0 ]
344
+ cast_new_output_name = f"{ cast_output_name } __io_cast_src"
345
+ nodes_to_add .append (
346
+ helper .make_node (
347
+ "Identity" ,
348
+ inputs = [cast_new_output_name ],
349
+ outputs = [cast_output_name ],
350
+ name = f"{ node .name } __io_cast_identity" ,
347
351
)
348
- node .output [0 ] = cast_new_output_name
352
+ )
353
+ # Rewire Cast to produce the new intermediate
354
+ node .output [0 ] = cast_new_output_name
349
355
350
356
for node in nodes_to_add :
351
357
self .model .graph .node .append (node )
0 commit comments