Skip to content

Commit 22bed90

Browse files
committed
[Rewriter] Introduce Flatten to reshape (#2301)
- Convert Flatten to reshape if possible - Merge Flatten + Reshape or Reshape + Flatten
1 parent bda23ab commit 22bed90

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

onnxscript/rewriter/basic_rules.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,55 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
302302
return check_result
303303

304304

305+
class Flatten2Reshape(RewriteRuleClassBase):
306+
"""Convert ``Flatten(x)`` to Reshape."""
307+
308+
def pattern(self, op, x: ir.Value):
309+
return op.Flatten(x)
310+
311+
def rewrite(self, op, x: ir.Value):
312+
new_shape = op.initializer(ir.Tensor(self._new_shape, name=f"{x.name}/shape"))
313+
return op.Reshape(x, new_shape)
314+
315+
def check(self, context, x: ir.Value) -> MatchResult:
316+
check_result = MatchResult()
317+
self._new_shape = np.array([-1, -1], "int64")
318+
319+
# Convert axis in a positive value if possible.
320+
axis = context.root.attributes.get_int("axis", 1)
321+
input_rank = None
322+
if (input_shape := x.shape) is not None:
323+
input_rank = len(input_shape)
324+
if axis < 0:
325+
axis += input_rank
326+
327+
# Compute reshape shape following axis attribute.
328+
if axis == 0:
329+
self._new_shape[0] = 1
330+
elif axis == 1:
331+
self._new_shape[0] = 0
332+
elif axis == input_rank:
333+
self._new_shape[1] = 1
334+
335+
# Try to update shape if output is known.
336+
if (output_shape := context.output_values[0].shape) is not None:
337+
for i, dim in enumerate(output_shape):
338+
if isinstance(dim, int):
339+
self._new_shape[i] = dim
340+
341+
# Try to update shape if input is known.
342+
if input_shape is not None:
343+
if all(isinstance(dim, int) for dim in input_shape[:axis]):
344+
self._new_shape[0] = np.prod(input_shape[:axis])
345+
if all(isinstance(dim, int) for dim in input_shape[axis:]):
346+
self._new_shape[1] = np.prod(input_shape[axis:])
347+
348+
# Verify if it is possible to apply rule.
349+
if np.count_nonzero(self._new_shape == -1) > 1:
350+
return check_result.fail("Impossible to compute new shape.")
351+
return check_result
352+
353+
305354
# Create rule instances
306355
cast_cast_rule = CastCast.rule()
307356
cast_identity_rule = CastIdentity.rule()
@@ -312,6 +361,7 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
312361
transpose_transpose_rule = TransposeTranspose.rule()
313362
unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
314363
squeeze_reshape_1d_rule = SqueezeReshape.rule()
364+
flatten_to_reshape_rule = Flatten2Reshape.rule()
315365

316366

317367
def basic_optimization_rules() -> RewriteRuleSet:
@@ -334,6 +384,7 @@ def basic_optimization_rules() -> RewriteRuleSet:
334384
cast_cast_rule,
335385
cast_identity_rule,
336386
expand_identity_rule,
387+
flatten_to_reshape_rule,
337388
reshape_reshape_rule,
338389
slice_split_rule,
339390
transpose_identity_rule,

onnxscript/rewriter/basic_rules_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,5 +551,65 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg):
551551
self.assertRegex(tracer_match.match_result.reason, error_msg)
552552

553553

554+
class Flatten2ReshapeTest(unittest.TestCase):
555+
@staticmethod
556+
def create_model(input_shape, axis=1):
557+
x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
558+
y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT))
559+
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
560+
561+
# Build the graph.
562+
tape.op("Flatten", inputs=[x], attributes={"axis": axis}, output=y)
563+
model = ir.Model(tape.graph_like, ir_version=10)
564+
return model
565+
566+
@parameterized.parameterized.expand(list(range(-5, 6)))
567+
def test_flatten_to_reshape_rule(self, axis):
568+
input_shape = (1, 4, 8, 7, 5)
569+
model = self.create_model(input_shape=input_shape, axis=axis)
570+
updated_model = clone_model(model)
571+
572+
# check rewrite approach.
573+
count = basic_rules.flatten_to_reshape_rule.apply_to_model(updated_model)
574+
self.assertEqual(count, 1)
575+
self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph])
576+
577+
# Check inference.
578+
inputs = np.random.default_rng(13).random(input_shape, dtype="float32")
579+
testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0)
580+
581+
@parameterized.parameterized.expand(list(range(-4, 5)))
582+
def test_flatten_to_reshape_dynamic_input(self, axis):
583+
model = self.create_model(input_shape=("N", "C1", "C2", "C3"), axis=axis)
584+
# Rule is supported in all cases if the output shape is known for non-special cases.
585+
input_shape = (1, 2, 3, 4)
586+
if axis not in {-3, 0, 1, 4}:
587+
out_shape = ir.Shape((np.prod(input_shape[:axis]), np.prod(input_shape[axis:])))
588+
model.graph.outputs[0].shape = out_shape
589+
updated_model = clone_model(model)
590+
591+
# check rewrite approach.
592+
count = basic_rules.flatten_to_reshape_rule.apply_to_model(updated_model)
593+
self.assertEqual(count, 1)
594+
self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph])
595+
596+
# Check inference.
597+
inputs = np.random.default_rng(17).random(input_shape, dtype="float32")
598+
testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0)
599+
600+
def test_unsupported_flatten_to_reshape(self):
601+
model = self.create_model(input_shape=("N", "C1", "C2"), axis=2)
602+
603+
# Check rewrite approach.
604+
tracer = MatchingTracer()
605+
count = basic_rules.flatten_to_reshape_rule.apply_to_model(model, tracer=tracer)
606+
self.assertEqual(count, 0)
607+
608+
# Check that the error message is the expected one
609+
tracer_match = tracer.best_matches_map[basic_rules.flatten_to_reshape_rule][0]
610+
self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
611+
self.assertRegex(tracer_match.match_result.reason, "Impossible to compute new shape")
612+
613+
554614
if __name__ == "__main__":
555615
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)