Skip to content

Commit 552e58a

Browse files
Copilotgramalingam
andcommitted
Fix lint errors by running lintrunner formatting
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
1 parent 97c3e61 commit 552e58a

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
pattern,
3333
redundant_scatter_nd,
3434
)
35-
from onnxscript.rewriter._basics import MatchingTracer, MatchResult, MatchStatus, MatchContext
35+
from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus
3636
from onnxscript.rewriter._rewrite_rule import (
3737
RewriterContext,
3838
RewriteRule,

onnxscript/rewriter/_basics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def print(self):
342342

343343
class MatchContext:
344344
"""A read-only context containing information about a pattern match.
345-
345+
346346
This class captures information about the context describing a match to a given pattern,
347347
providing access to the model, graph/function, root node, output values, and all
348348
nodes of the matching subgraph.
@@ -356,7 +356,7 @@ def __init__(
356356
match_result: MatchResult,
357357
) -> None:
358358
"""Initialize the pattern match context.
359-
359+
360360
Args:
361361
model: The model being matched.
362362
graph_or_function: The graph or function being matched.
@@ -395,7 +395,7 @@ def nodes(self) -> Sequence[ir.Node]:
395395

396396
def display(self, *, in_graph_order: bool = True) -> None:
397397
"""Display the nodes in the pattern match context.
398-
398+
399399
Args:
400400
in_graph_order: If True, display nodes in the order they appear in the
401401
graph/function. If False, display nodes in the order they appear
@@ -404,7 +404,7 @@ def display(self, *, in_graph_order: bool = True) -> None:
404404
nodes = self.nodes
405405
if not nodes:
406406
return
407-
407+
408408
if in_graph_order:
409409
# Display nodes in same order as in graph/function
410410
for node in self._graph_or_function:

onnxscript/rewriter/match_context_test.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
class MatchContextTest(unittest.TestCase):
1414
def test_context_usage_in_condition_function(self):
1515
"""Test that MatchContext can be meaningfully used in condition functions."""
16-
16+
1717
model_proto = onnx.parser.parse_model(
1818
"""
1919
<ir_version: 7, opset_import: [ "" : 17]>
@@ -26,33 +26,31 @@ def test_context_usage_in_condition_function(self):
2626
"""
2727
)
2828
model = ir.serde.deserialize_model(model_proto)
29-
29+
3030
def condition_using_context(context, x, y):
3131
# Use context to check properties of the match
3232
self.assertIs(context.model, model)
3333
self.assertIs(context.graph_or_function, model.graph)
3434
self.assertIs(context.root, model.graph[2])
35-
35+
3636
# Verify that we can inspect the matched nodes
3737
self.assertEqual(len(context.nodes), 2)
38-
38+
3939
return True # Allow the rewrite
40-
40+
4141
def reciprocal_mul_pattern(op, x, y):
4242
return (1 / x) * y
4343

4444
def replacement(op, x, y):
4545
return op.Div(y, x)
4646

4747
rule = pattern.RewriteRule(
48-
reciprocal_mul_pattern,
49-
replacement,
50-
condition_function=condition_using_context
48+
reciprocal_mul_pattern, replacement, condition_function=condition_using_context
5149
)
52-
50+
5351
count = rule.apply_to_model(model)
5452
self.assertEqual(count, 1)
5553

5654

57-
if __name__ == '__main__':
58-
unittest.main()
55+
if __name__ == "__main__":
56+
unittest.main()

0 commit comments

Comments
 (0)