56 changes: 1 addition & 55 deletions onnxscript/rewriter/collapse_slices.py
@@ -71,71 +71,17 @@ def _identity_to_itself(op, data, **_):
return op.Identity(data)


def _identity_to_updates(op, data, indices, updates, **_):
"""Return the updates as the output.

This is used when the ScatterND is redundant in terms of
updating the whole data with the updates.

"""
return op.Identity(updates)


def _potential_redundant_slice(op, data, starts, ends, axes, steps):
"""To identify a slice op"""
return op.Slice(data, starts, ends, axes, steps)


def _potential_redundant_scatternd(op, data, indices, updates):
"""To identify a ScatterND op"""
return op.ScatterND(data, indices, updates)


def _check_if_redundant_scatternd(
context,
data: ir.Value,
indices: ir.Value,
updates: ir.Value,
**_,
):
"""If the indices is the same length as the first dim of data, and the shape of updates is equal to data, we can simply swap the whole value."""
del context # Reserved for future extensions

# To validate data can be replaced directly by updates, we need to check the following:
# 1. they have the same shape
if data.shape is None:
logger.info("The value 'data' shape is not statically known.")
return False
if updates.shape is None:
logger.info("The value 'updates' shape is not statically known.")
return False
if data.shape != updates.shape:
logger.info("The shape of 'data' and 'updates' are different.")
return False

# 2. the indices is referring to the whole data, which is from 0 to data.shape[0]
if indices.const_value is None:
logger.info("The value 'indices' is not statically known.")
return False
if indices.const_value.numpy().tolist() != [[i] for i in range(data.shape[0])]: # type: ignore[arg-type]
logger.info("The 'indices' is not referring to the whole data.")
return False

return True


# Register the rewrite rules
remove_redundant_slice = pattern.RewriteRule(
_potential_redundant_slice,
_identity_to_itself,
_check_if_redundant_slice,
)

remove_redundant_scatternd = pattern.RewriteRule(
_potential_redundant_scatternd,
_identity_to_updates,
_check_if_redundant_scatternd,
)

# NOTE: The order of the rules is important. Larger patterns should be checked first.
rules = pattern.RewriteRuleSet([remove_redundant_slice, remove_redundant_scatternd])
rules = pattern.RewriteRuleSet([remove_redundant_slice])
32 changes: 0 additions & 32 deletions onnxscript/rewriter/collapse_slices_test.py
@@ -82,35 +82,3 @@ def test_slice_pattern_is_not_matched_when_input_is_dynamic(self):
model = ir.serde.deserialize_model(model_proto)
count = collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 0)

def test_scatternd_is_redundant_when_it_is_updating_the_whole_input_in_order(self):
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[112, 16, 512] data, float[112, 16, 512] updates) => (float[112, 16, 512] output)
{
output = ScatterND (data, indices, updates)
}
"""
)
# Use inserted initializers to avoid manually coding the large constants
indices = np.arange(112).reshape(112, 1).astype(np.int64)
model = ir.serde.deserialize_model(model_proto)
# from numpy to ir.Tensor
indices_ir_tensor = ir.Tensor(
name="indices",
value=indices,
)
# assign the tensor to a value
indices = model.graph[0].inputs[1]
indices.const_value = indices_ir_tensor
model.graph.initializers["indices"] = indices
original_model_proto = ir.serde.serialize_model(model)

count = collapse_slices.rules.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(len(model.graph), 1)
self.assertIn("Identity", [node.op_type for node in model.graph])

input = np.random.rand(112, 16, 512).astype(np.float32)
testing.assert_numerically_equal(original_model_proto, model, (input, input))
62 changes: 53 additions & 9 deletions onnxscript/rewriter/redundant_scatter_nd.py
@@ -1,12 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Rewrite rule to eliminate redundant ScatterND operations.
"""Rewrite rules to eliminate redundant ScatterND operations.

Identify ScatterND(data, indices, updates) that can be replaced by Identity(updates).
This is generated by the translation of `x[:, ...] = y` in PyTorch.
The specific pattern is that the updated indices take the form [[0], ..., [S-1]] for the first dimension,
where S is the size of the first dimension of the updated-data tensor.
In effect, the scatter-update ends up being an assignment of a new value to the entire tensor.
This module contains two rewrite rules:

1. ScatterAllDynamic: Identifies ScatterND(data, indices, updates) that can be replaced by Identity(updates)
when the indices are computed dynamically using Range operations but represent a complete update
of an entire axis. This is generated by the translation of `x[:, ...] = y` in PyTorch.

2. ScatterAllStatic: Identifies ScatterND(data, indices, updates) that can be replaced by Identity(updates)
when the indices are statically known constants in the form [[0], [1], ..., [n-1]] covering
the entire first dimension of the data tensor.

Both rules detect when the scatter-update ends up being an assignment of a new value to the entire tensor.
"""

from __future__ import annotations
@@ -22,7 +28,7 @@ def fail(*args):
return onnxscript.rewriter.MatchResult().fail(*args)


class ScatterAll(orp.RewriteRuleClassBase):
class ScatterAllDynamic(orp.RewriteRuleClassBase):
def pattern(self, op, data, axis, transposed_data, updates):
# Construct update-indices spanning an entire axis:
shape = op.Shape(data, start=0)
@@ -60,6 +66,44 @@ def rewrite(self, op, updates, **_):
return op.Identity(updates)


rule = ScatterAll.rule()
class ScatterAllStatic(orp.RewriteRuleClassBase):
"""Rewrite rule for eliminating redundant ScatterND with statically known indices.

This handles the case where indices are constant values in the form [[0], [1], ..., [n-1]]
that update the entire first dimension of the data tensor.
"""

def pattern(self, op, data, indices, updates):
"""Pattern to match ScatterND with static indices."""
return op.ScatterND(data, indices, updates)

def check(self, context, data, indices, updates, **_):
"""Check if the ScatterND is redundant due to static indices covering entire tensor."""
# To validate that data can be replaced directly by updates, we need to check the following:
# 1. data and updates have the same shape
if data.shape is None:
return fail("The value 'data' shape is not statically known.", data)
if updates.shape is None:
return fail("The value 'updates' shape is not statically known.", updates)
if data.shape != updates.shape:
return fail("The shape of 'data' and 'updates' are different.", data, updates)

# 2. the indices refer to the whole data, i.e., rows 0 through data.shape[0] - 1
if indices.const_value is None:
return fail("The value 'indices' is not statically known.", indices)
expected_indices = [[i] for i in range(data.shape[0])]
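# For example, if data.shape[0] == 3, expected_indices is [[0], [1], [2]].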
actual_indices = indices.const_value.numpy().tolist()
if actual_indices != expected_indices:
return fail("The 'indices' is not referring to the whole data.", indices)

return True

def rewrite(self, op, updates, **_):
"""Replace ScatterND with Identity since updates covers entire tensor."""
return op.Identity(updates)


rule = ScatterAllDynamic.rule()
static_rule = ScatterAllStatic.rule()

rules = orp.RewriteRuleSet([rule])
rules = orp.RewriteRuleSet([rule, static_rule])
57 changes: 56 additions & 1 deletion onnxscript/rewriter/redundant_scatter_nd_test.py
@@ -5,6 +5,7 @@
import unittest

import numpy as np
import onnx.parser
import onnx_ir as ir
import onnxruntime
from onnx_ir.passes.common import CheckerPass, ShapeInferencePass
@@ -19,7 +20,9 @@


class RedundantScatterNdTest(unittest.TestCase):
def test_redundant_scatter_nd(self):
def test_redundant_scatter_nd_dynamic_indices(self):
"""Test redundant ScatterND with dynamically constructed indices."""

@script()
def model_script(
data: FLOAT[8, "N", 16], updates: FLOAT[8, "N", 16]
@@ -62,9 +65,61 @@ def model_script(
optimized_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
)
optimized_outputs = optimized_session.run(None, inputs)
# Compare outputs
for output, optimized_output in zip(outputs, optimized_outputs):
np.testing.assert_allclose(output, optimized_output, rtol=1e-6, atol=1e-6)

def test_redundant_scatter_nd_static_indices(self):
"""Test redundant ScatterND with static indices (moved from collapse_slices_test.py)."""
model_proto = onnx.parser.parse_model(
"""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[112, 16, 512] data, float[112, 16, 512] updates) => (float[112, 16, 512] output)
{
output = ScatterND (data, indices, updates)
}
"""
)
# Use inserted initializers to avoid manually coding the large constants
indices = np.arange(112).reshape(112, 1).astype(np.int64)
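# indices is [[0], [1], ..., [111]]: one entry for every index of the first dimension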
model = ir.serde.deserialize_model(model_proto)
# Convert the NumPy array to an ir.Tensor
indices_ir_tensor = ir.Tensor(
name="indices",
value=indices,
)
# Assign the tensor as the const_value of the 'indices' input and register it as an initializer
indices_value = model.graph[0].inputs[1]
indices_value.const_value = indices_ir_tensor
model.graph.initializers["indices"] = indices_value
original_model_proto = ir.serde.serialize_model(model)

count = redundant_scatter_nd.rules.apply_to_model(model)
self.assertEqual(count, 1)
self.assertEqual(len(model.graph), 1)
self.assertIn("Identity", [node.op_type for node in model.graph])

# Test numerical equivalence
input_data = np.random.rand(112, 16, 512).astype(np.float32)
inputs = {"data": input_data, "updates": input_data}

# Run original model
session = onnxruntime.InferenceSession(
original_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
)
original_outputs = session.run(None, inputs)

# Run optimized model
optimized_model_proto = ir.serde.serialize_model(model)
optimized_session = onnxruntime.InferenceSession(
optimized_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
)
optimized_outputs = optimized_session.run(None, inputs)

# Compare outputs
for original_output, optimized_output in zip(original_outputs, optimized_outputs):
np.testing.assert_allclose(original_output, optimized_output, rtol=1e-6, atol=1e-6)


if __name__ == "__main__":
unittest.main()