11from typing import Any
2- from dataclasses import dataclass
2+ from dataclasses import field , dataclass
33
44from kirin import ir
5+ from kirin .interp import InterpreterError
56from kirin .lattice import EmptyLattice
67from kirin .analysis import Forward , ForwardFrame
78from kirin .validation import ValidationPass
89
910from bloqade import squin
11+ from bloqade .analysis .address import Address , AddressReg , AddressQubit , AddressAnalysis
1012
1113
14+ @dataclass
1215class _GeminiLogicalValidationAnalysis (Forward [EmptyLattice ]):
1316 keys = ["gemini.validate.logical" ]
1417
15- first_gate = True
1618 lattice = EmptyLattice
19+ addr_frame : ForwardFrame [Address ]
20+
21+ first_gates : dict [int , bool ] = field (init = False , default_factory = dict )
1722
1823 def eval_fallback (self , frame : ForwardFrame , node : ir .Statement ):
1924 if isinstance (node , squin .gate .stmts .Gate ):
20- # NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
21- self .first_gate = False
25+ raise InterpreterError (f"Missing implementation for gate { node } " )
2226
2327 return tuple (self .lattice .bottom () for _ in range (len (node .results )))
2428
29+ def check_first_gate (self , qubits : ir .SSAValue ) -> bool :
30+ address = self .addr_frame .get (qubits )
31+
32+ if isinstance (address , AddressQubit ):
33+ is_first = self .first_gates .get (address .data , True )
34+ self .first_gates [address .data ] = False
35+ return is_first
36+ elif isinstance (address , AddressReg ):
37+ is_first = True
38+ for addr_int in address .data :
39+ is_first = is_first and self .first_gates .get (addr_int , True )
40+ self .first_gates [addr_int ] = False
41+
42+ return is_first
43+
44+ # NOTE: we should have a flat kernel with simple address analysis, so in case we don't
45+ # get concrete addresses, we might as well error here since something's wrong
46+ return False
47+
2548 def method_self (self , method : ir .Method ) -> EmptyLattice :
2649 return self .lattice .bottom ()
2750
@@ -34,7 +57,10 @@ def name(self) -> str:
3457 return "Gemini Logical Validation"
3558
3659 def run (self , method : ir .Method ) -> tuple [Any , list [ir .ValidationError ]]:
37- analysis = _GeminiLogicalValidationAnalysis (method .dialects )
60+ addr_frame , _ = AddressAnalysis (method .dialects ).run (method )
61+ analysis = _GeminiLogicalValidationAnalysis (
62+ method .dialects , addr_frame = addr_frame
63+ )
3864 frame , _ = analysis .run (method )
3965
4066 return frame , analysis .get_validation_errors ()
0 commit comments