@@ -351,8 +351,9 @@ def sanitize_io_casts(self) -> None:
351351 """
352352 model_input_names = {input .name for input in self .model .graph .input }
353353 model_output_names = {output .name for output in self .model .graph .output }
354- nodes_to_add = []
355- for node in self .model .graph .node :
354+ insertions : list [tuple [int , onnx .NodeProto ]] = []
355+
356+ for idx , node in enumerate (self .model .graph .node ):
356357 if (
357358 node .op_type == "Cast"
358359 and node .input
@@ -363,23 +364,19 @@ def sanitize_io_casts(self) -> None:
363364 # Unique per graph output to avoid collisions when multiple outputs are cast from the same input
364365 cast_output_name = node .output [0 ]
365366 cast_new_output_name = f"{ cast_output_name } __io_cast_src"
366- nodes_to_add .append (
367- helper .make_node (
368- "Identity" ,
369- inputs = [cast_new_output_name ],
370- outputs = [cast_output_name ],
371- name = f"{ node .name } __io_cast_identity" ,
372- )
367+ identity_node = helper .make_node (
368+ "Identity" ,
369+ inputs = [cast_new_output_name ],
370+ outputs = [cast_output_name ],
371+ name = f"{ node .name } __io_cast_identity" ,
373372 )
374373 # Rewire Cast to produce the new intermediate
375374 node .output [0 ] = cast_new_output_name
375+ insertions .append ((idx + 1 , identity_node ))
376376
377- for node in nodes_to_add :
378- self .model .graph .node .append (node )
379-
380- # Make sure the graph is topologically sorted
381- gs_graph = gs .import_onnx (self .model ).cleanup ().toposort ()
382- self .model = gs .export_onnx (gs_graph )
377+ # Insert Identities in-order right after their corresponding Casts
378+ for offset , (pos , id_node ) in enumerate (insertions ):
379+ self .model .graph .node .insert (pos + offset , id_node )
383380
384381 def _create_layernorm_node (self , pattern : dict ) -> onnx .NodeProto :
385382 """Create a LayerNormalization node with optional bias."""
0 commit comments