Skip to content

Commit 23b2e12

Browse files
Copilotgramalingam
andcommitted
Fix tutorial example and add display method to PatternMatchContext
- Fix tutorial to use `x in model.graph.input` instead of checking input names - Add display method to PatternMatchContext with in_graph_order parameter - Fix test assertion to use model.graph[2] instead of model.graph.node[2] Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
1 parent d77b37a commit 23b2e12

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

docs/tutorial/rewriter/conditional_rewrite.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ def advanced_condition_check(context, x, y, **_):
8282

8383
# Access the broader graph context and check that x occurs as a graph-input
8484
model = context.model
85-
input_names = [input.name for input in model.graph.input]
86-
if x not in input_names:
85+
if x not in model.graph.input:
8786
return False
8887

8988
# You can inspect the matched nodes for advanced validation

onnxscript/rewriter/_basics.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,28 @@ def nodes(self) -> Sequence[ir.Node]:
393393
"""All the nodes of the matching subgraph."""
394394
return self._match_result.nodes
395395

396+
def display(self, *, in_graph_order: bool = True) -> None:
397+
"""Display the nodes in the pattern match context.
398+
399+
Args:
400+
in_graph_order: If True, display nodes in the order they appear in the
401+
graph/function. If False, display nodes in the order they appear
402+
in the match result.
403+
"""
404+
nodes = self.nodes
405+
if not nodes:
406+
return
407+
408+
if in_graph_order:
409+
# Display nodes in same order as in graph/function
410+
for node in self._graph_or_function:
411+
if node in nodes:
412+
node.display()
413+
else:
414+
# Display nodes in match order
415+
for node in nodes:
416+
node.display()
417+
396418

397419
class MatchingTracer:
398420
"""A debugging helper class to trace the matching of a pattern against a graph.

onnxscript/rewriter/pattern_match_context_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def condition_using_context(context, x, y):
3131
# Use context to check properties of the match
3232
self.assertIs(context.model, model)
3333
self.assertIs(context.graph_or_function, model.graph)
34-
self.assertIs(context.main_root_node, model.graph.node[2])
34+
self.assertIs(context.main_root_node, model.graph[2])
3535

3636
# Verify that we can inspect the matched nodes
3737
self.assertEqual(len(context.nodes), 2)

0 commit comments

Comments
 (0)