diff --git a/onnxscript/rewriter/_matcher.py b/onnxscript/rewriter/_matcher.py index a007926c37..61dffab6f9 100644 --- a/onnxscript/rewriter/_matcher.py +++ b/onnxscript/rewriter/_matcher.py @@ -183,14 +183,14 @@ def _match_value( self, pattern_value: _pattern_ir.ValuePattern, value: ir.Value | None ) -> bool: """Match an IR value against a ValuePattern instance.""" - if value is not None and value.graph is not self._graph_or_function: + if value is not None and value.graph is not self._graph: if not isinstance( pattern_value, (_pattern_ir.Var, _pattern_ir.Constant, _pattern_ir.AnyValue) ): # If the pattern value is a Var, Constant, or AnyValue, we allow it to match # values from other graphs. Otherwise, we fail the match. return self.fail( - f"Value {value.name} is not in the graph {self._graph_or_function.name}. " + f"Value {value.name} is not in the graph {self._graph.name}. " f"Pattern matches crossing graph boundaries are not supported." ) if isinstance(pattern_value, _pattern_ir.AnyValue): @@ -362,7 +362,10 @@ def match( complications which require careful consideration. """ self._tracer = tracer - self._graph_or_function = graph_or_function[0].graph + if isinstance(graph_or_function, ir.Graph): + self._graph: ir.Graph = graph_or_function + else: + self._graph = graph_or_function.graph if self.pattern.has_single_output_node: self._init_match(verbose) return self._match_single_output_node(