1010 Address ,
1111 AddressAnalysis ,
1212)
13- from bloqade .analysis .address .lattice import AddressReg , AddressQubit
13+ from bloqade .analysis .address .lattice import (
14+ Unknown ,
15+ AddressReg ,
16+ UnknownReg ,
17+ AddressQubit ,
18+ PartialIList ,
19+ PartialTuple ,
20+ UnknownQubit ,
21+ )
1422
15- from .lattice import QubitValidation
23+ from .lattice import May , Top , Must , Bottom , QubitValidation
1624
1725
1826class QubitValidationError (ValidationError ):
19- """ValidationError that records which qubit and gate caused the violation ."""
27+ """ValidationError for definite (Must) violations with concrete qubit addresses ."""
2028
2129 qubit_id : int
2230 gate_name : str
2331
2432 def __init__ (self , node : ir .IRNode , qubit_id : int , gate_name : str ):
25- # message stored in ValidationError so formatting/hint() will include it
2633 super ().__init__ (node , f"Qubit[{ qubit_id } ] cloned at { gate_name } gate." )
2734 self .qubit_id = qubit_id
2835 self .gate_name = gate_name
2936
3037
38+ class PotentialQubitValidationError (ValidationError ):
39+ """ValidationError for potential (May) violations with unknown addresses."""
40+
41+ gate_name : str
42+ condition : str
43+
44+ def __init__ (self , node : ir .IRNode , gate_name : str , condition : str ):
45+ super ().__init__ (node , f"Potential cloning at { gate_name } gate{ condition } ." )
46+ self .gate_name = gate_name
47+ self .condition = condition
48+
49+
3150class NoCloningValidation (Forward [QubitValidation ]):
3251 """
3352 Validates the no-cloning theorem by tracking qubit addresses.
@@ -40,9 +59,7 @@ class NoCloningValidation(Forward[QubitValidation]):
4059 _address_frame : ForwardFrame [Address ] = field (init = False )
4160 _type_frame : ForwardFrame = field (init = False )
4261 method : ir .Method
43- _validation_errors : list [QubitValidationError ] = field (
44- default_factory = list , init = False
45- )
62+ _validation_errors : list [ValidationError ] = field (default_factory = list , init = False )
4663
4764 def __init__ (self , mtd : ir .Method ):
4865 """
@@ -88,68 +105,126 @@ def eval_stmt_fallback(
88105 ) -> tuple [QubitValidation , ...]:
89106 """
90107 Default statement evaluation: check for qubit usage violations.
108+ Returns Bottom, May, Must, or Top depending on what we can prove.
91109 """
92110
93111 if not isinstance (stmt , func .Invoke ):
94- return tuple (QubitValidation . bottom () for _ in stmt .results )
112+ return tuple (Bottom () for _ in stmt .results )
95113
96114 address_frame = self ._address_frame
97115 if address_frame is None :
98- return tuple (QubitValidation . top () for _ in stmt .results )
116+ return tuple (Top () for _ in stmt .results )
99117
100- has_qubit_args = any (
101- isinstance ( address_frame . get ( arg ), ( AddressQubit , AddressReg ))
102- for arg in stmt . args
103- )
118+ concrete_addrs : list [ int ] = []
119+ has_unknown = False
120+ has_qubit_args = False
121+ unknown_arg_names : list [ str ] = []
104122
105- if not has_qubit_args :
106- return tuple (QubitValidation .bottom () for _ in stmt .results )
107-
108- used_addrs : list [int ] = []
109123 for arg in stmt .args :
110124 addr = address_frame .get (arg )
111- qubit_addrs = self .get_qubit_addresses (addr )
112- used_addrs .extend (qubit_addrs )
125+ match addr :
126+ case AddressQubit (data = qubit_addr ):
127+ has_qubit_args = True
128+ concrete_addrs .append (qubit_addr )
129+ case AddressReg (data = addrs ):
130+ has_qubit_args = True
131+ concrete_addrs .extend (addrs )
132+ case UnknownQubit () | UnknownReg () | Unknown ():
133+ has_qubit_args = True
134+ has_unknown = True
135+ arg_name = self ._get_source_name (arg )
136+ unknown_arg_names .append (arg_name )
137+ case _:
138+ pass
139+
140+ if not has_qubit_args :
141+ return tuple (Bottom () for _ in stmt .results )
113142
114143 seen : set [int ] = set ()
115- violations : list [str ] = []
144+ must_violations : list [str ] = []
145+ gate_name = stmt .callee .sym_name .upper ()
116146
117- for qubit_addr in used_addrs :
147+ for qubit_addr in concrete_addrs :
118148 if qubit_addr in seen :
119- gate_name = stmt . callee . sym_name . upper ( )
120- violations .append (self . format_violation ( qubit_addr , gate_name ) )
149+ violation = self . format_violation ( qubit_addr , gate_name )
150+ must_violations .append (violation )
121151 self ._validation_errors .append (
122152 QubitValidationError (stmt , qubit_addr , gate_name )
123153 )
124154 seen .add (qubit_addr )
125155
126- if not violations :
127- return tuple (QubitValidation (violations = frozenset ()) for _ in stmt .results )
156+ if must_violations :
157+ usage = Must (violations = frozenset (must_violations ))
158+ elif has_unknown :
159+ args_str = " == " .join (unknown_arg_names )
160+ if len (unknown_arg_names ) > 1 :
161+ condition = f", when { args_str } "
162+ else :
163+ condition = f", with unknown index { args_str } "
164+
165+ self ._validation_errors .append (
166+ PotentialQubitValidationError (stmt , gate_name , condition )
167+ )
168+
169+ usage = May (
170+ violations = frozenset ([f"Unknown qubits at { gate_name } Gate{ condition } " ])
171+ )
172+ else :
173+ usage = Bottom ()
128174
129- usage = QubitValidation (violations = frozenset (violations ))
130175 return tuple (usage for _ in stmt .results ) if stmt .results else (usage ,)
131176
177+ def _get_source_name (self , value : ir .SSAValue ) -> str :
178+ """Trace back to get the source variable name for a value.
179+
180+ For getitem operations like q[a], returns 'a'.
181+ For direct values, returns the value's name.
182+ """
183+ from kirin .dialects .py .indexing import GetItem
184+
185+ if isinstance (value , ir .ResultValue ) and isinstance (value .stmt , GetItem ):
186+ index_arg = value .stmt .args [1 ]
187+ return self ._get_source_name (index_arg )
188+
189+ if isinstance (value , ir .BlockArgument ):
190+ return value .name or f"arg{ value .index } "
191+
192+ if hasattr (value , "name" ) and value .name :
193+ return value .name
194+
195+ return str (value )
196+
132197 def run_method (
133198 self , method : ir .Method , args : tuple [QubitValidation , ...]
134199 ) -> tuple [ForwardFrame [QubitValidation ], QubitValidation ]:
135200 self_mt = self .method_self (method )
136201 return self .run_callable (method .code , (self_mt ,) + args )
137202
138203 def raise_validation_errors (self ):
139- """Raise validation error for each no-cloning violation found .
204+ """Raise validation errors for both definite and potential violations .
140205 Points to source file and line with snippet.
141206 """
142207 if not self ._validation_errors :
143208 return
144209
145- # If multiple errors, print all with snippets first
146- if len (self ._validation_errors ) > 1 :
147- for err in self ._validation_errors :
148- err .attach (self .method )
149- # Print error message before snippet
210+ # Print all errors with snippets
211+ for err in self ._validation_errors :
212+ err .attach (self .method )
213+
214+ # Format error message based on type
215+ if isinstance (err , QubitValidationError ):
150216 print (
151- f"\033 [31mValidation Error \033 [0m: Cloned qubit [{ err .qubit_id } ] at { err .gate_name } gate. "
217+ f"\n \ 033 [31mError \033 [0m: Cloning qubit [{ err .qubit_id } ] at { err .gate_name } gate"
152218 )
153- print (err .hint ())
154- print (f"Raised { len (self ._validation_errors )} error(s)." )
219+ elif isinstance (err , PotentialQubitValidationError ):
220+ print (
221+ f"\n \033 [33mWarning\033 [0m: Potential cloning at { err .gate_name } gate{ err .condition } "
222+ )
223+ else :
224+ print (
225+ f"\n \033 [31mError\033 [0m: { err .args [0 ] if err .args else type (err ).__name__ } "
226+ )
227+
228+ print (err .hint ())
229+
155230 raise
0 commit comments