
Commit 30d17e1

Merge branch 'main' into dis
2 parents: ae4efc0 + 2ff01f7

File tree

2 files changed: +66, -7 lines

onnxscript/optimizer/_constant_folding.py

Lines changed: 65 additions & 6 deletions
@@ -278,9 +278,18 @@ def _get_numpy_value(
         if size_limit is not None and const_value.size > size_limit:
             return None
         try:
-            # Reinterpret the array with `.view()` because some implementations of
-            # ir.TensorProtocol (e.g. PyTorch<=2.7) do not use ml_dtypes for bfloat16 etc.
-            array = const_value.numpy().view(const_value.dtype.numpy())
+            # Turn the constant value into a numpy array representation, with the
+            # specifics of this conversion handled by the tensor type.
+            array = const_value.numpy()
+            # Strings cannot be reinterpreted via .view(); doing so raises
+            # "TypeError: Cannot change data-type for array of references."
+            # There is also no reason to reinterpret strings; this is only
+            # relevant for some arithmetic types.
+            if const_value.dtype != ir.DataType.STRING:
+                # Reinterpret the array with `.view()` because some
+                # implementations of ir.TensorProtocol (e.g. PyTorch<=2.7) do
+                # not use ml_dtypes for bfloat16 etc.
+                array = array.view(const_value.dtype.numpy())
         except FileNotFoundError:
             # External data is not available.
             logger.warning(
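Note on the new STRING guard: ONNX string tensors materialize as numpy object arrays, and numpy refuses to reinterpret arrays of Python object references, which is exactly the TypeError quoted in the comment above. A minimal standalone sketch of the failure mode (assumes only numpy; not part of this commit):

import numpy as np

# String tensors come back as object arrays; .view() cannot reinterpret
# an array of references as a fixed-width dtype.
strings = np.array([b"a", b"bc"], dtype=object)
try:
    strings.view(np.int64)
except TypeError as err:
    print(err)  # Cannot change data-type for array of references.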
@@ -344,6 +353,33 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
     return default


+@register("Add")
+def add(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+    """Propagate symbolic dim values."""
+
+    def get_dim_value(input_index):
+        input = _get_input(node, input_index)
+        if input is None:
+            return None
+        shape_value: ir.Shape | None = state.get_shape_value(input)
+        if shape_value is None or len(shape_value) != 1:
+            return None
+        dim: int | ir.SymbolicDim = shape_value[0]
+        return dim if isinstance(dim, int) else dim.value
+
+    dim0 = get_dim_value(0)
+    dim1 = get_dim_value(1)
+    if dim0 is None or dim1 is None:
+        return None
+    if isinstance(dim0, int) and isinstance(dim1, int):
+        result_dim_value: int | ir.SymbolicDim = dim0 + dim1
+    else:
+        result_dim_value = ir.SymbolicDim(f"{dim0}+{dim1}")
+    output = _get_output(node, 0)
+    if output is not None:
+        state.set_sym_value(output, ir.Shape([result_dim_value]))
+
+
 @register("Abs")
 def abs(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     """Replace an Abs node by Identity when applicable.
@@ -392,9 +428,26 @@ def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     return None


+def _propagate_shape_value(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+    """Propagates symbolic shape value of input 0 to output 0.
+
+    Applies to ops like Reshape/Squeeze/Unsqueeze where the shape of the tensor may change
+    but the values in the tensor remain the same.
+    """
+    input = _get_input(node, 0)
+    input_shape_value = state.get_shape_value(input)
+    output = _get_output(node, 0)
+    if output is not None and input_shape_value is not None:
+        state.set_sym_value(output, input_shape_value)
+    return None
+
+
 @register("Reshape")
 def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
-    """Replace a Reshape node by Identity when applicable."""
+    """Replace a Reshape node by Identity when applicable.
+
+    Also propagate symbolic shape values.
+    """
     input = _get_input(node, 0)
     shape = _get_input(node, 1)
     if input is None or shape is None:
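_propagate_shape_value encodes the observation that ops like Reshape, Squeeze, and Unsqueeze rearrange a tensor without touching its element values, so any symbolic shape value tracked for the input is still valid for the output. A hypothetical miniature, with a plain dict standing in for the OptimizerState symbolic-value map (illustrative names, not the real API):

# Known symbolic shape values, keyed by value name.
sym_values = {"shape_in": ["batch", 128]}

def propagate_shape_value(input_name, output_name):
    value = sym_values.get(input_name)
    if value is not None:
        # The op only rearranges the tensor, so the tracked value
        # carries over to the output unchanged.
        sym_values[output_name] = value

propagate_shape_value("shape_in", "shape_out")
assert sym_values["shape_out"] == ["batch", 128]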
@@ -404,12 +457,18 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
     shape_value = state.get_shape_value(shape)

     if shape_value is None or input_shape is None:
-        return None
+        return _propagate_shape_value(node, op, state)

     # No need to check for special values like -1, 0, etc. here
     if _same_shape(input_shape, shape_value):
         return op.Identity(input)
-    return None
+    return _propagate_shape_value(node, op, state)
+
+
+@register("Squeeze")
+def squeeze(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
+    """Propagate symbolic shape values."""
+    return _propagate_shape_value(node, op, state)


 @register("Cast")

onnxscript/rewriter/onnx_fusions/_rotary_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def _rotate_half_pattern(op, x, start1, end1, start2, end2):

 class RotaryEmbedding23Fusion(pattern.RewriteRuleClassBase):
     def __init__(self):
-        super().__init__(name="RotaryEmbedding23", as_function=True)
+        super().__init__(name="RotaryEmbedding23")

     def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
         return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin
