@@ -278,9 +278,18 @@ def _get_numpy_value(
278278 if size_limit is not None and const_value .size > size_limit :
279279 return None
280280 try :
281- # Reinterpret the array with `.view()` because some implementations of
282- # ir.TensorProtocol (e.g. PyTorch<=2.7) do not use ml_dtypes for bfloat16 etc.
283- array = const_value .numpy ().view (const_value .dtype .numpy ())
281+ # Turn the constant value into a numpy array representation with the
282+ # specifics of this conversion handled by the tensor type
283+ array = const_value .numpy ()
284+ # Can/should not reinterpret strings via .view, resulting in
285+ # "TypeError: Cannot change data-type for array of references."
286+ # There is also no reason to reinterpret strings, this is only
287+ # relevant for some arithmetic types
288+ if const_value .dtype != ir .DataType .STRING :
289+ # Reinterpret the array with `.view()` because some
290+ # implementations of ir.TensorProtocol (e.g. PyTorch<=2.7) do
291+ # not use ml_dtypes for bfloat16 etc.
292+ array = array .view (const_value .dtype .numpy ())
284293 except FileNotFoundError :
285294 # External data is not available.
286295 logger .warning (
@@ -344,6 +353,33 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
344353 return default
345354
346355
356+ @register ("Add" )
357+ def add (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
358+ """Propagate symbolic dim values."""
359+
360+ def get_dim_value (input_index ):
361+ input = _get_input (node , input_index )
362+ if input is None :
363+ return None
364+ shape_value : ir .Shape | None = state .get_shape_value (input )
365+ if shape_value is None or len (shape_value ) != 1 :
366+ return None
367+ dim : int | ir .SymbolicDim = shape_value [0 ]
368+ return dim if isinstance (dim , int ) else dim .value
369+
370+ dim0 = get_dim_value (0 )
371+ dim1 = get_dim_value (1 )
372+ if dim0 is None or dim1 is None :
373+ return None
374+ if isinstance (dim0 , int ) and isinstance (dim1 , int ):
375+ result_dim_value : int | ir .SymbolicDim = dim0 + dim1
376+ else :
377+ result_dim_value = ir .SymbolicDim (f"{ dim0 } +{ dim1 } " )
378+ output = _get_output (node , 0 )
379+ if output is not None :
380+ state .set_sym_value (output , ir .Shape ([result_dim_value ]))
381+
382+
347383@register ("Abs" )
348384def abs (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
349385 """Replace an Abs node by Identity when applicable.
@@ -392,9 +428,26 @@ def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
392428 return None
393429
394430
431+ def _propagate_shape_value (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
432+ """Propagates symbolic shape value of input 0 to output 0.
433+
434+ Applies to ops like Reshape/Squeeze/Unsqueeze where the shape of the tensor may change
435+ but the values in the tensor remain the same.
436+ """
437+ input = _get_input (node , 0 )
438+ input_shape_value = state .get_shape_value (input )
439+ output = _get_output (node , 0 )
440+ if output is not None and input_shape_value is not None :
441+ state .set_sym_value (output , input_shape_value )
442+ return None
443+
444+
395445@register ("Reshape" )
396446def reshape (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
397- """Replace a Reshape node by Identity when applicable."""
447+ """Replace a Reshape node by Identity when applicable.
448+
449+ Also propagate symbolic shape values.
450+ """
398451 input = _get_input (node , 0 )
399452 shape = _get_input (node , 1 )
400453 if input is None or shape is None :
@@ -404,12 +457,18 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
404457 shape_value = state .get_shape_value (shape )
405458
406459 if shape_value is None or input_shape is None :
407- return None
460+ return _propagate_shape_value ( node , op , state )
408461
409462 # No need to check for special values like -1, 0, etc. here
410463 if _same_shape (input_shape , shape_value ):
411464 return op .Identity (input )
412- return None
465+ return _propagate_shape_value (node , op , state )
466+
467+
468+ @register ("Squeeze" )
469+ def squeeze (node : ir .Node , op , state : OptimizerState ) -> ReturnValue :
470+ """Propagate symbolic shape values."""
471+ return _propagate_shape_value (node , op , state )
413472
414473
415474@register ("Cast" )
0 commit comments