@@ -496,11 +496,6 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     if input is None or output is None:
         return None

-    # TODO(rama): Parts of the following logic (implementing type/shape inference
-    # for Cast op) should be unnecessary. Generic incremental shape-inference
-    # should handle this. Only the optimization to eliminate redundant Cast ops
-    # should be needed here.
-
     input_dtype = _get_input_element_type(node, 0)
     output_dtype = _get_int_attribute(node, "to", None)
     if output_dtype is not None:
@@ -904,7 +899,7 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:


 def _merge_shapes(
-    preferred_shapes: ir.Shape | None, referenced_shapes: ir.Shape | None
+    preferred_shape: ir.Shape | None, other_shape: ir.Shape | None
 ) -> ir.Shape | None:
     """Merge two shapes, preferring dimensions from preferred_shapes."""

@@ -919,14 +914,14 @@ def merge_dims(dim1, dim2):
             return dim2
         return dim1

-    if preferred_shapes is None:
-        return referenced_shapes
-    if referenced_shapes is None:
-        return preferred_shapes
-    if len(preferred_shapes) != len(referenced_shapes):
+    if preferred_shape is None:
+        return other_shape
+    if other_shape is None:
+        return preferred_shape
+    if len(preferred_shape) != len(other_shape):
         raise ValueError("Shapes must have the same rank.")
     return ir.Shape(
-        [merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shapes, referenced_shapes)]
+        [merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shape, other_shape)]
     )

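To make the renamed helper's behavior concrete, here is a minimal standalone sketch of the merge semantics shown in the hunk above. It is not the library code: `merge_shapes_sketch` and the plain-list representation (ints for known dimensions, `None` for unknown ones) are hypothetical stand-ins, and the real `_merge_shapes` operates on `ir.Shape` objects with a `merge_dims` policy that is only partially visible in this hunk.

```python
# Standalone illustration (assumed semantics, not the onnxscript implementation).
def merge_shapes_sketch(preferred_shape, other_shape):
    """Merge two shapes, preferring dimensions from preferred_shape."""
    if preferred_shape is None:
        return other_shape
    if other_shape is None:
        return preferred_shape
    if len(preferred_shape) != len(other_shape):
        raise ValueError("Shapes must have the same rank.")

    def merge_dims(dim1, dim2):
        # Assumed policy: keep the preferred dimension unless it is unknown
        # and the other shape supplies a concrete value.
        if dim1 is None:
            return dim2
        return dim1

    return [merge_dims(d1, d2) for d1, d2 in zip(preferred_shape, other_shape)]


# The preferred shape's known dimensions win; its unknown dimensions are
# filled in from the other shape.
assert merge_shapes_sketch([2, None, 8], [None, 4, 16]) == [2, 4, 8]
assert merge_shapes_sketch(None, [3, 5]) == [3, 5]
```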