Skip to content

Commit bda23ab

Browse files
committed
[Rewrite] Improve ReshapeReshape rule (#2301)
- remove pointless check in shape ignored - (conditional) support negative shape - (conditional) support zero shape
1 parent 4bd81b2 commit bda23ab

File tree

2 files changed

+113
-20
lines changed

2 files changed

+113
-20
lines changed

onnxscript/rewriter/basic_rules.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
from typing import ClassVar, Sequence
1313

14+
import numpy as np
15+
1416
from onnxscript import ir
1517
from onnxscript.rewriter import _ir_utils as ir_utils
1618
from onnxscript.rewriter._basics import MatchResult
@@ -123,16 +125,37 @@ def pattern(self, op, x, shape_ignored, shape):
123125
return op.Reshape(op.Reshape(x, shape_ignored), shape)
124126

125127
def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
126-
return op.Reshape(x, shape)
128+
new_shape = op.initializer(ir.Tensor(self._new_shape, name=shape.name))
129+
return op.Reshape(x, new_shape, allowzero=self._allowzero)
127130

128131
def check(self, context, x, shape_ignored, shape) -> MatchResult:
129132
check_result = MatchResult()
130-
if shape_ignored.const_value is None:
131-
return check_result.fail("Shape ignored is not a constant.")
132-
if shape.const_value is None:
133+
134+
# Shape must be a constant.
135+
if (np_shape := ir_utils.get_numpy_value(shape)) is None:
133136
return check_result.fail("Shape is not a constant.")
134-
if shape.const_value.numpy().min() <= 0:
135-
return check_result.fail("Shape has non-positive values.")
137+
# Convert to array to support assignment destination.
138+
self._new_shape = np.array(np_shape, np_shape.dtype)
139+
140+
# Try to replace {0,-1} values in shape if reshape output is known.
141+
if (reshape_output := context.output_values[0].shape) is not None:
142+
for i, dim in enumerate(reshape_output):
143+
if isinstance(dim, int) and dim > 0:
144+
self._new_shape[i] = dim
145+
146+
# Constraints for shape.
147+
self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0)
148+
if self._allowzero == 1 and any(self._new_shape == 0):
149+
return check_result
150+
if any(self._new_shape == 0) and any(self._new_shape < 0):
151+
return check_result.fail("Shape cannot contain both 0 and -1 dimensions.")
152+
elif np.count_nonzero(self._new_shape == 0) > 1:
153+
return check_result.fail("Shape cannot contain more than one 0 dimension.")
154+
155+
# At this point, we can safely replace '0' with '-1'.
156+
# Note allowzero is removed since at this point it does not have any effect.
157+
self._allowzero = None
158+
self._new_shape = np.where(self._new_shape == 0, -1, self._new_shape)
136159
return check_result
137160

138161

onnxscript/rewriter/basic_rules_test.py

Lines changed: 84 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,9 @@ def model3(X: ot.FLOAT[1, 1]):
414414

415415
class ReshapeReshapeTest(unittest.TestCase):
416416
@staticmethod
417-
def create_model(input_shape, shape1, shape2):
417+
def create_model(
418+
input_shape, shape1, shape2, allowzero1=0, allowzero2=0, infer_shape=False
419+
):
418420
def _convert_shape(shape, name):
419421
if isinstance(shape, np.ndarray):
420422
shape = tape.initializer(ir.Tensor(shape, name=name))
@@ -430,20 +432,43 @@ def _convert_shape(shape, name):
430432
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
431433

432434
# Build the graph.
433-
reshape = tape.op("Reshape", inputs=[x, _convert_shape(shape1, "shape_")])
434-
tape.op("Reshape", inputs=[reshape, _convert_shape(shape2, "shape")], output=y)
435+
reshape = tape.op(
436+
"Reshape",
437+
inputs=[x, _convert_shape(shape1, "shape_")],
438+
attributes={"allowzero": allowzero1},
439+
)
440+
tape.op(
441+
"Reshape",
442+
inputs=[reshape, _convert_shape(shape2, "shape")],
443+
attributes={"allowzero": allowzero2},
444+
output=y,
445+
)
435446
model = ir.Model(tape.graph_like, ir_version=10)
447+
448+
# Infer shapes.
449+
if infer_shape:
450+
model = ir.passes.common.ShapeInferencePass()(model).model
436451
return model
437452

438453
@parameterized.parameterized.expand(
439454
[
440455
((3, 4, 5), [4, 5, 3], [5, 4, 3]),
441456
((3, 4, 5), [4, 5, 3], [5, 4, 3]),
457+
((3, 4, 8), [2, 0, 3, -1], [0, 3, 2, 8]),
458+
((3, 4, 8), [3, 4, -1], [-1, 12], 1),
459+
((3, 4, 2), [0, 4, -1], [12, -1], 0, 1),
460+
((3, 0, 8), [4, 2, 0, 0], [3, 0], 1, 1),
442461
]
443462
)
444-
def test_reshape_reshape_rule(self, input_shape, shape1, shape2):
463+
def test_reshape_reshape_rule(
464+
self, input_shape, shape1, shape2, allowzero1=0, allowzero2=0
465+
):
445466
model = self.create_model(
446-
input_shape, np.array(shape1, dtype="int64"), np.array(shape2, dtype="int64")
467+
input_shape,
468+
np.array(shape1, dtype="int64"),
469+
np.array(shape2, dtype="int64"),
470+
allowzero1=allowzero1,
471+
allowzero2=allowzero2,
447472
)
448473
updated_model = clone_model(model)
449474

@@ -456,19 +481,64 @@ def test_reshape_reshape_rule(self, input_shape, shape1, shape2):
456481
inputs = np.random.default_rng(10).random(input_shape, dtype="float32")
457482
testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0)
458483

484+
@parameterized.parameterized.expand([([3, 2, 3, 3, 3], 1), ([0, -1, 3, 2], 0)])
485+
def test_reshape_dynamic_reshape_rule(self, shape1, allowzero1=0):
486+
input_shape = (3, 6, 9)
487+
shape1 = np.array(shape1, dtype="int64")
488+
# Build the model with unknown shape1.
489+
model = self.create_model(
490+
input_shape,
491+
(shape1.size,),
492+
np.array((1, 6, 27), dtype="int64"),
493+
allowzero1=allowzero1,
494+
)
495+
updated_model = clone_model(model)
496+
497+
# check rewrite approach.
498+
count = basic_rules.reshape_reshape_rule.apply_to_model(updated_model)
499+
self.assertEqual(count, 1)
500+
self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph])
501+
502+
# Check inference.
503+
feeds = {
504+
"X": np.random.default_rng(2).random(input_shape, dtype="float32"),
505+
"shape_": shape1,
506+
}
507+
testing.assert_numerically_equal(model, updated_model, feeds, atol=0, rtol=0)
508+
509+
@parameterized.parameterized.expand(
510+
[((3, 6, 9), [0, 3, 2, -1]), ((0, 6, 2), [0, 0, 3], 1)]
511+
)
512+
def test_reshape_reshape_dynamic_rule(self, input_shape, shape2, allowzero2=0):
513+
# Note that shape inference is required for this test to be valid.
514+
shape2 = np.array(shape2, dtype="int64")
515+
model = self.create_model(
516+
input_shape,
517+
np.array((3, 2, -1), dtype="int64"),
518+
shape2,
519+
allowzero2=allowzero2,
520+
infer_shape=True,
521+
)
522+
updated_model = clone_model(model)
523+
524+
# check rewrite approach.
525+
count = basic_rules.reshape_reshape_rule.apply_to_model(updated_model)
526+
self.assertEqual(count, 1)
527+
self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph])
528+
529+
# Check inference.
530+
inputs = np.random.default_rng(7).random(input_shape, dtype="float32")
531+
testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0)
532+
459533
@parameterized.parameterized.expand(
460534
[
461-
((2,), np.array([1, 6], dtype="int64"), "ignored is not a constant"),
462-
(np.array([1, 6], dtype="int64"), (3,), "is not a constant"),
463-
(
464-
np.array([1, 6], dtype="int64"),
465-
np.array([0, 6], dtype="int64"),
466-
"non-positive values",
467-
),
535+
((3,), "is not a constant"),
536+
(np.array([0, -1], dtype="int64"), "both 0 and -1 dimensions"),
537+
(np.array([0, 0, 3], dtype="int64"), "more than one 0 dimension"),
468538
]
469539
)
470-
def test_unsupported_reshape_reshape(self, shape1, shape2, error_msg):
471-
model = self.create_model((1, 2, 3), shape1, shape2)
540+
def test_unsupported_reshape_reshape(self, shape2, error_msg):
541+
model = self.create_model((1, 2, 3), np.array([1, 6], dtype="int64"), shape2)
472542

473543
# Check rewrite approach.
474544
tracer = MatchingTracer()

0 commit comments

Comments
 (0)