1- from dataclasses import field
1+ from typing import Any
22
33from kirin import ir
4- from kirin .analysis import Forward , TypeInference
4+ from kirin .analysis import Forward
55from kirin .dialects import func
66from kirin .ir .exception import ValidationError
77from kirin .analysis .forward import ForwardFrame
1515 AddressReg ,
1616 UnknownReg ,
1717 AddressQubit ,
18+ PartialIList ,
19+ PartialTuple ,
1820 UnknownQubit ,
1921)
22+ from bloqade .analysis .validation .validationpass import ValidationPass
2023
2124from .lattice import May , Top , Must , Bottom , QubitValidation
2225
@@ -45,66 +48,39 @@ def __init__(self, node: ir.IRNode, gate_name: str, condition: str):
4548 self .condition = condition
4649
4750
48- class NoCloningValidation (Forward [QubitValidation ]):
49- """
50- Validates the no-cloning theorem by tracking qubit addresses.
51-
52- Built on top of AddressAnalysis to get qubit address information.
53- """
51+ class _NoCloningAnalysis (Forward [QubitValidation ]):
52+ """Internal forward analysis for tracking qubit cloning violations."""
5453
5554 keys = ["validate.nocloning" ]
5655 lattice = QubitValidation
57- _address_frame : ForwardFrame [Address ] = field (init = False )
58- _type_frame : ForwardFrame = field (init = False )
59- method : ir .Method
60- _validation_errors : list [ValidationError ] = field (default_factory = list , init = False )
6156
62- def __init__ (self , mtd : ir .Method ):
63- """
64- Input:
65- - an ir.Method / kernel function
66- infer dialects from it and remember method.
67- """
68- self .method = mtd
69- super ().__init__ (mtd .dialects )
57+ def __init__ (self , dialects ):
58+ super ().__init__ (dialects )
59+ self ._address_frame : ForwardFrame [Address ] | None = None
60+ self ._validation_errors : list [ValidationError ] = []
7061
7162 def initialize (self ):
7263 super ().initialize ()
7364 self ._validation_errors = []
74- address_analysis = AddressAnalysis (self .dialects )
75- address_analysis .initialize ()
76- self ._address_frame , _ = address_analysis .run_analysis (self .method )
77-
78- type_inference = TypeInference (self .dialects )
79- type_inference .initialize ()
80- self ._type_frame , _ = type_inference .run_analysis (self .method )
81-
8265 return self
8366
84- def method_self (self , method : ir .Method ) -> QubitValidation :
85- return self .lattice .bottom ()
67+ def run_method (
68+ self , method : ir .Method , args : tuple [QubitValidation , ...]
69+ ) -> tuple [ForwardFrame [QubitValidation ], QubitValidation ]:
70+ if self ._address_frame is None :
71+ if getattr (self , "_address_analysis" , None ) is None :
72+ addr_analysis = AddressAnalysis (self .dialects )
73+ addr_analysis .initialize ()
74+ self ._address_analysis = addr_analysis
8675
87- def get_qubit_addresses (self , addr : Address ) -> frozenset [int ]:
88- """Extract concrete qubit addresses from an Address lattice element."""
89- match addr :
90- case AddressQubit (data = qubit_addr ):
91- return frozenset ([qubit_addr ])
92- case AddressReg (data = addrs ):
93- return frozenset (addrs )
94- case _:
95- return frozenset ()
76+ self ._address_frame , _ = self ._address_analysis .run_analysis (method )
9677
97- def format_violation (self , qubit_id : int , gate_name : str ) -> str :
98- """Return the violation string for a qubit + gate."""
99- return f"Qubit[{ qubit_id } ] on { gate_name } Gate"
78+ return self .run_callable (method .code , args )
10079
10180 def eval_stmt_fallback (
10281 self , frame : ForwardFrame [QubitValidation ], stmt : ir .Statement
10382 ) -> tuple [QubitValidation , ...]:
104- """
105- Default statement evaluation: check for qubit usage violations.
106- Returns Bottom, May, Must, or Top depending on what we can prove.
107- """
83+ """Check for qubit usage violations."""
10884
10985 if not isinstance (stmt , func .Invoke ):
11086 return tuple (Bottom () for _ in stmt .results )
@@ -127,7 +103,13 @@ def eval_stmt_fallback(
127103 case AddressReg (data = addrs ):
128104 has_qubit_args = True
129105 concrete_addrs .extend (addrs )
130- case UnknownQubit () | UnknownReg () | Unknown ():
106+ case (
107+ UnknownQubit ()
108+ | UnknownReg ()
109+ | PartialIList ()
110+ | PartialTuple ()
111+ | Unknown ()
112+ ):
131113 has_qubit_args = True
132114 has_unknown = True
133115 arg_name = self ._get_source_name (arg )
@@ -144,7 +126,7 @@ def eval_stmt_fallback(
144126
145127 for qubit_addr in concrete_addrs :
146128 if qubit_addr in seen :
147- violation = self . format_violation ( qubit_addr , gate_name )
129+ violation = f"Qubit[ { qubit_addr } ] on { gate_name } Gate"
148130 must_violations .append (violation )
149131 self ._validation_errors .append (
150132 QubitValidationError (stmt , qubit_addr , gate_name )
@@ -171,11 +153,7 @@ def eval_stmt_fallback(
171153 return tuple (usage for _ in stmt .results ) if stmt .results else (usage ,)
172154
173155 def _get_source_name (self , value : ir .SSAValue ) -> str :
174- """Trace back to get the source variable name for a value.
175-
176- For getitem operations like q[a], returns 'a'.
177- For direct values, returns the value's name.
178- """
156+ """Trace back to get the source variable name."""
179157 from kirin .dialects .py .indexing import GetItem
180158
181159 if isinstance (value , ir .ResultValue ) and isinstance (value .stmt , GetItem ):
@@ -190,24 +168,52 @@ def _get_source_name(self, value: ir.SSAValue) -> str:
190168
191169 return str (value )
192170
193- def run_method (
194- self , method : ir .Method , args : tuple [QubitValidation , ...]
195- ) -> tuple [ForwardFrame [QubitValidation ], QubitValidation ]:
196- self_mt = self .method_self (method )
197- return self .run_callable (method .code , (self_mt ,) + args )
198171
199- def raise_validation_errors (self ):
200- """Raise validation errors for both definite and potential violations.
201- Points to source file and line with snippet.
172+ class NoCloningValidation (ValidationPass ):
173+ """Validates the no-cloning theorem by tracking qubit addresses."""
174+
175+ def __init__ (self ):
176+ self .method : ir .Method | None = None
177+ self ._analysis : _NoCloningAnalysis | None = None
178+ self ._cached_address_frame = None
179+
180+ def name (self ) -> str :
181+ return "No-Cloning Validation"
182+
183+ def get_required_analyses (self ) -> list [type ]:
184+ """Declare dependency on AddressAnalysis."""
185+ return [AddressAnalysis ]
186+
187+ def set_analysis_cache (self , cache : dict [type , Any ]) -> None :
188+ """Use cached AddressAnalysis result."""
189+ self ._cached_address_frame = cache .get (AddressAnalysis )
190+
191+ def run (self , method : ir .Method ) -> tuple [Any , list [ValidationError ]]:
192+ """Run the no-cloning validation analysis.
193+
194+ Returns:
195+ - frame: ForwardFrame with QubitValidation lattice values
196+ - errors: List of validation errors found
202197 """
203- if not self ._validation_errors :
204- return
198+ if self ._analysis is None :
199+ self ._analysis = _NoCloningAnalysis (method .dialects )
200+
201+ self .method = method
202+ self ._analysis .initialize ()
203+ if self ._cached_address_frame is not None :
204+ self ._analysis ._address_frame = self ._cached_address_frame
205+ frame , _ = self ._analysis .run_analysis (method , args = None )
205206
206- # Print all errors with snippets
207- for err in self ._validation_errors :
208- err .attach (self .method )
207+ return frame , self ._analysis ._validation_errors
209208
210- # Format error message based on type
209+ def print_validation_errors (self ):
210+ """Print all collected errors with formatted snippets."""
211+ if self ._analysis is None :
212+ return
213+ errors = self ._analysis ._validation_errors
214+ if not errors :
215+ return
216+ for err in errors :
211217 if isinstance (err , QubitValidationError ):
212218 print (
213219 f"\n \033 [31mError\033 [0m: Cloning qubit [{ err .qubit_id } ] at { err .gate_name } gate"
@@ -220,7 +226,4 @@ def raise_validation_errors(self):
220226 print (
221227 f"\n \033 [31mError\033 [0m: { err .args [0 ] if err .args else type (err ).__name__ } "
222228 )
223-
224229 print (err .hint ())
225-
226- raise
0 commit comments