33from kirin import ir
44from kirin .analysis import Forward , TypeInference
55from kirin .dialects import func
6+ from kirin .ir .exception import ValidationError
67from kirin .analysis .forward import ForwardFrame
78
89from bloqade .analysis .address import (
1415from .lattice import QubitValidation
1516
1617
18+ class QubitValidationError (ValidationError ):
19+ """ValidationError that records which qubit and gate caused the violation."""
20+
21+ qubit_id : int
22+ gate_name : str
23+
24+ def __init__ (self , node : ir .IRNode , qubit_id : int , gate_name : str ):
25+ # message stored in ValidationError so formatting/hint() will include it
26+ super ().__init__ (node , f"Qubit[{ qubit_id } ] cloned at { gate_name } gate." )
27+ self .qubit_id = qubit_id
28+ self .gate_name = gate_name
29+
30+
1731class NoCloningValidation (Forward [QubitValidation ]):
1832 """
1933 Validates the no-cloning theorem by tracking qubit addresses.
@@ -26,7 +40,9 @@ class NoCloningValidation(Forward[QubitValidation]):
2640 _address_frame : ForwardFrame [Address ] = field (init = False )
2741 _type_frame : ForwardFrame = field (init = False )
2842 method : ir .Method
29- _validation_errors : list [str ] = field (default_factory = list , init = False )
43+ _validation_errors : list [QubitValidationError ] = field (
44+ default_factory = list , init = False
45+ )
3046
3147 def __init__ (self , mtd : ir .Method ):
3248 """
@@ -63,13 +79,9 @@ def get_qubit_addresses(self, addr: Address) -> frozenset[int]:
6379 case _:
6480 return frozenset ()
6581
66- def get_stmt_info (self , stmt : ir .Statement ) -> str :
67- """String Report about the statement for violation messages."""
68- if isinstance (stmt , func .Invoke ) and hasattr (stmt , "callee" ):
69- gate_name = stmt .callee .sym_name .upper ()
70- return f"{ gate_name } Gate"
71-
72- return f"{ stmt .__class__ .__name__ } @{ stmt } "
82+ def format_violation (self , qubit_id : int , gate_name : str ) -> str :
83+ """Return the violation string for a qubit + gate."""
84+ return f"Qubit[{ qubit_id } ] on { gate_name } Gate"
7385
7486 def eval_stmt_fallback (
7587 self , frame : ForwardFrame [QubitValidation ], stmt : ir .Statement
@@ -101,13 +113,13 @@ def eval_stmt_fallback(
101113
102114 seen : set [int ] = set ()
103115 violations : list [str ] = []
104- stmt_info = self .get_stmt_info (stmt )
105116
106117 for qubit_addr in used_addrs :
107118 if qubit_addr in seen :
108- violations .append (f"Qubit[{ qubit_addr } ] on { stmt_info } " )
119+ gate_name = stmt .callee .sym_name .upper ()
120+ violations .append (self .format_violation (qubit_addr , gate_name ))
109121 self ._validation_errors .append (
110- f"Qubit[ { qubit_addr } ] on { stmt_info } in { stmt . source } "
122+ QubitValidationError ( stmt , qubit_addr , gate_name )
111123 )
112124 seen .add (qubit_addr )
113125
@@ -123,6 +135,21 @@ def run_method(
123135 self_mt = self .method_self (method )
124136 return self .run_callable (method .code , (self_mt ,) + args )
125137
126- def get_validation_errors (self ) -> str :
127- """Retrieve collected validation error messages."""
128- return "\n " .join (self ._validation_errors )
138+ def raise_validation_errors (self ):
139+ """Raise validation error for each no-cloning violation found.
140+ Points to source file and line with snippet.
141+ """
142+ if not self ._validation_errors :
143+ return
144+
145+ # If multiple errors, print all with snippets first
146+ if len (self ._validation_errors ) > 1 :
147+ for err in self ._validation_errors :
148+ err .attach (self .method )
149+ # Print error message before snippet
150+ print (
151+ f"\033 [31mValidation Error\033 [0m: Cloned qubit [{ err .qubit_id } ] at { err .gate_name } gate."
152+ )
153+ print (err .hint ())
154+ print (f"Raised { len (self ._validation_errors )} error(s)." )
155+ raise
0 commit comments