@@ -71,7 +71,7 @@ def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidati
7171 def eval_fallback (
7272 self , frame : ForwardFrame [QubitValidation ], node : ir .Statement
7373 ) -> tuple [QubitValidation , ...]:
74- """Check for qubit usage violations."""
74+ """Check for qubit usage violations and return lattice values ."""
7575 if not isinstance (node , func .Invoke ):
7676 return tuple (Bottom () for _ in node .results )
7777
@@ -83,6 +83,7 @@ def eval_fallback(
8383 has_unknown = False
8484 has_qubit_args = False
8585 unknown_arg_names : list [str ] = []
86+
8687 for arg in node .args :
8788 addr = address_frame .get (arg )
8889 match addr :
@@ -110,34 +111,25 @@ def eval_fallback(
110111 return tuple (Bottom () for _ in node .results )
111112
112113 seen : set [int ] = set ()
113- must_violations : list [ str ] = []
114- s_name = getattr (node .callee , "sym_name" , "<unknown" )
114+ violations : set [ tuple [ int , str ]] = set ()
115+ s_name = getattr (node .callee , "sym_name" , "<unknown> " )
115116 gate_name = s_name .upper ()
116117
117118 for qubit_addr in concrete_addrs :
118119 if qubit_addr in seen :
119- violation = f"Qubit[{ qubit_addr } ] on { gate_name } Gate"
120- must_violations .append (violation )
121- self .add_validation_error (
122- node , QubitValidationError (node , qubit_addr , gate_name )
123- )
124-
120+ violations .add ((qubit_addr , gate_name ))
125121 seen .add (qubit_addr )
126122
127- if must_violations :
128- usage = Must (violations = frozenset (must_violations ))
123+ if violations :
124+ usage = Must (violations = frozenset (violations ))
129125 elif has_unknown :
130126 args_str = " == " .join (unknown_arg_names )
131127 if len (unknown_arg_names ) > 1 :
132128 condition = f", when { args_str } "
133129 else :
134130 condition = f", with unknown argument { args_str } "
135131
136- self .add_validation_error (
137- node , PotentialQubitValidationError (node , gate_name , condition )
138- )
139-
140- usage = May (violations = frozenset ([f"{ gate_name } Gate{ condition } " ]))
132+ usage = May (violations = frozenset ([(gate_name , condition )]))
141133 else :
142134 usage = Bottom ()
143135
@@ -159,6 +151,48 @@ def _get_source_name(self, value: ir.SSAValue) -> str:
159151
160152 return str (value )
161153
154+ def extract_errors_from_frame (
155+ self , frame : ForwardFrame [QubitValidation ]
156+ ) -> list [ValidationError ]:
157+ """Extract validation errors from final lattice values.
158+
159+ Only extracts errors from top-level statements (not nested in regions).
160+ """
161+ errors = []
162+ seen_statements = set ()
163+
164+ for node , value in frame .entries .items ():
165+ if isinstance (node , ir .ResultValue ):
166+ stmt = node .stmt
167+ elif isinstance (node , ir .Statement ):
168+ stmt = node
169+ else :
170+ continue
171+ if stmt in seen_statements :
172+ continue
173+ seen_statements .add (stmt )
174+ if isinstance (value , Must ):
175+ for qubit_id , gate_name in value .violations :
176+ errors .append (QubitValidationError (stmt , qubit_id , gate_name ))
177+ elif isinstance (value , May ):
178+ for gate_name , condition in value .violations :
179+ errors .append (
180+ PotentialQubitValidationError (stmt , gate_name , condition )
181+ )
182+ return errors
183+
184+ def count_violations (self , frame : Any ) -> int :
185+ """Count individual violations from the frame, same as test helper."""
186+ from .lattice import May , Must
187+
188+ total = 0
189+ for node , value in frame .entries .items ():
190+ if isinstance (value , Must ):
191+ total += len (value .violations )
192+ elif isinstance (value , May ):
193+ total += len (value .violations )
194+ return total
195+
162196
163197class NoCloningValidation (ValidationPass ):
164198 """Validates the no-cloning theorem by tracking qubit addresses."""
@@ -179,37 +213,39 @@ def set_analysis_cache(self, cache: dict[type, Any]) -> None:
179213 self ._cached_address_frame = cache .get (AddressAnalysis )
180214
181215 def run (self , method : ir .Method ) -> tuple [Any , list [ValidationError ]]:
182- """Run the no-cloning validation analysis.
183-
184- Returns:
185- - frame: ForwardFrame with QubitValidation lattice values
186- - errors: List of validation errors found
187- """
216+ """Run the no-cloning validation analysis."""
188217 if self ._analysis is None :
189218 self ._analysis = _NoCloningAnalysis (method .dialects )
190219
191220 self ._analysis .initialize ()
192221 if self ._cached_address_frame is not None :
193222 self ._analysis ._address_frame = self ._cached_address_frame
223+
194224 frame , _ = self ._analysis .run (method )
195- return frame , self ._analysis .get_validation_errors ()
225+ errors = self ._analysis .extract_errors_from_frame (frame )
226+
227+ return frame , errors
196228
197229 def print_validation_errors (self ):
198230 """Print all collected errors with formatted snippets."""
199231 if self ._analysis is None :
200232 return
201- validation_errors = self ._analysis .get_validation_errors ()
202- for err in validation_errors :
203- if isinstance (err , QubitValidationError ):
204- print (
205- f"\n \033 [31mError\033 [0m: Cloning qubit [{ err .qubit_id } ] at { err .gate_name } gate"
206- )
207- elif isinstance (err , PotentialQubitValidationError ):
208- print (
209- f"\n \033 [33mWarning\033 [0m: Potential cloning at { err .gate_name } gate{ err .condition } "
210- )
211- else :
212- print (
213- f"\n \033 [31mError\033 [0m: { err .args [0 ] if err .args else type (err ).__name__ } "
214- )
215- print (err .hint ())
233+
234+ if self ._analysis .state ._current_frame :
235+ frame = self ._analysis .state ._current_frame
236+ errors = self ._analysis .extract_errors_from_frame (frame )
237+
238+ for err in errors :
239+ if isinstance (err , QubitValidationError ):
240+ print (
241+ f"\n \033 [31mError\033 [0m: Cloning qubit [{ err .qubit_id } ] at { err .gate_name } gate"
242+ )
243+ elif isinstance (err , PotentialQubitValidationError ):
244+ print (
245+ f"\n \033 [33mWarning\033 [0m: Potential cloning at { err .gate_name } gate{ err .condition } "
246+ )
247+ else :
248+ print (
249+ f"\n \033 [31mError\033 [0m: { err .args [0 ] if err .args else type (err ).__name__ } "
250+ )
251+ print (err .hint ())
0 commit comments