From ab254f2bdaa918c2f8884c320cf7d48c08adb8ea Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Fri, 5 Dec 2025 16:25:36 +0100 Subject: [PATCH 1/2] Fix state initialization validation in gemini logical --- .../analysis/logical_validation/analysis.py | 36 ++++++++++++++++--- .../analysis/logical_validation/impls.py | 34 ++++++++++++++++-- test/gemini/test_logical_validation.py | 19 +++++++++- 3 files changed, 81 insertions(+), 8 deletions(-) diff --git a/src/bloqade/gemini/analysis/logical_validation/analysis.py b/src/bloqade/gemini/analysis/logical_validation/analysis.py index 2b98672b..b4238dbf 100644 --- a/src/bloqade/gemini/analysis/logical_validation/analysis.py +++ b/src/bloqade/gemini/analysis/logical_validation/analysis.py @@ -1,27 +1,50 @@ from typing import Any -from dataclasses import dataclass +from dataclasses import field, dataclass from kirin import ir +from kirin.interp import InterpreterError from kirin.lattice import EmptyLattice from kirin.analysis import Forward, ForwardFrame from kirin.validation import ValidationPass from bloqade import squin +from bloqade.analysis.address import Address, AddressReg, AddressQubit, AddressAnalysis +@dataclass class _GeminiLogicalValidationAnalysis(Forward[EmptyLattice]): keys = ["gemini.validate.logical"] - first_gate = True lattice = EmptyLattice + addr_frame: ForwardFrame[Address] + + first_gates: dict[int, bool] = field(init=False, default_factory=dict) def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): if isinstance(node, squin.gate.stmts.Gate): - # NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here - self.first_gate = False + raise InterpreterError(f"Missing implementation for gate {node}") return tuple(self.lattice.bottom() for _ in range(len(node.results))) + def check_first_gate(self, qubits: ir.SSAValue) -> bool: + address = self.addr_frame.get(qubits) + + if isinstance(address, AddressQubit): + is_first = self.first_gates.get(address.data, True) + self.first_gates[address.data] = False + return is_first + elif isinstance(address, AddressReg): + is_first = True + for addr_int in address.data: + is_first = is_first and self.first_gates.get(addr_int, True) + self.first_gates[addr_int] = False + + return is_first + + # NOTE: we should have a flat kernel with simple address analysis, so in case we don't + # get concrete addresses, we might as well error here since something's wrong + return False + def method_self(self, method: ir.Method) -> EmptyLattice: return self.lattice.bottom() @@ -34,7 +57,10 @@ def name(self) -> str: return "Gemini Logical Validation" def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]: - analysis = _GeminiLogicalValidationAnalysis(method.dialects) + addr_frame, _ = AddressAnalysis(method.dialects).run(method) + analysis = _GeminiLogicalValidationAnalysis( + method.dialects, addr_frame=addr_frame + ) frame, _ = analysis.run(method) return frame, analysis.get_validation_errors() diff --git a/src/bloqade/gemini/analysis/logical_validation/impls.py b/src/bloqade/gemini/analysis/logical_validation/impls.py index c576ef0d..f915a6bc 100644 --- a/src/bloqade/gemini/analysis/logical_validation/impls.py +++ b/src/bloqade/gemini/analysis/logical_validation/impls.py @@ -80,8 +80,7 @@ def non_clifford( frame: ForwardFrame, stmt: gate.stmts.SingleQubitGate | gate.stmts.RotationGate, ): - if interp.first_gate: - interp.first_gate = False + if interp.check_first_gate(stmt.qubits): return () interp.add_validation_error( @@ -92,3 +91,34 @@ def non_clifford( ), ) return () + + @_interp.impl(gate.stmts.X) + @_interp.impl(gate.stmts.Y) + @_interp.impl(gate.stmts.SqrtX) + @_interp.impl(gate.stmts.SqrtY) + @_interp.impl(gate.stmts.Z) + @_interp.impl(gate.stmts.H) + @_interp.impl(gate.stmts.S) + def clifford( + self, + interp: _GeminiLogicalValidationAnalysis, + frame: ForwardFrame, + stmt: gate.stmts.SingleQubitGate, + ): + # NOTE: ignore result, but make sure the first gate flag is set to False + interp.check_first_gate(stmt.qubits) + + return () + + @_interp.impl(gate.stmts.CX) + @_interp.impl(gate.stmts.CY) + @_interp.impl(gate.stmts.CZ) + def controlled_gate( + self, + interp: _GeminiLogicalValidationAnalysis, + frame: ForwardFrame, + stmt: gate.stmts.ControlledGate, + ): + interp.check_first_gate(stmt.controls) + interp.check_first_gate(stmt.targets) + return () diff --git a/test/gemini/test_logical_validation.py b/test/gemini/test_logical_validation.py index 736d49fa..384b922d 100644 --- a/test/gemini/test_logical_validation.py +++ b/test/gemini/test_logical_validation.py @@ -4,6 +4,7 @@ from bloqade import squin, gemini from bloqade.types import Qubit +from bloqade.analysis.address import AddressAnalysis from bloqade.gemini.analysis.logical_validation.analysis import ( GeminiLogicalValidation, _GeminiLogicalValidationAnalysis, @@ -35,7 +36,10 @@ def main(): if m2: squin.y(q[2]) - frame, _ = _GeminiLogicalValidationAnalysis(main.dialects).run_no_raise(main) + addr_frame, _ = AddressAnalysis(main.dialects).run(main) + frame, _ = _GeminiLogicalValidationAnalysis( + main.dialects, addr_frame=addr_frame + ).run_no_raise(main) main.print(analysis=frame.entries) @@ -182,3 +186,16 @@ def main(n: int): assert len(e.errors) == 4 assert did_error + + +def test_non_clifford_parallel_gates(): + @gemini.logical.kernel + def main(): + q = squin.qalloc(5) + squin.rx(0.123, q[0]) + squin.broadcast.ry(0.333, q[1:]) + + squin.broadcast.x(q) + squin.broadcast.h(q[1:]) + + main.print() From b7fba430fa3c64d87e22a68fcc2625923c338e97 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 9 Dec 2025 09:22:02 +0100 Subject: [PATCH 2/2] Remove unreachable branch --- .../analysis/logical_validation/analysis.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/src/bloqade/gemini/analysis/logical_validation/analysis.py b/src/bloqade/gemini/analysis/logical_validation/analysis.py index b4238dbf..0195e011 100644 --- a/src/bloqade/gemini/analysis/logical_validation/analysis.py +++ b/src/bloqade/gemini/analysis/logical_validation/analysis.py @@ -8,7 +8,7 @@ from kirin.validation import ValidationPass from bloqade import squin -from bloqade.analysis.address import Address, AddressReg, AddressQubit, AddressAnalysis +from bloqade.analysis.address import Address, AddressReg, AddressAnalysis @dataclass @@ -29,21 +29,17 @@ def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): def check_first_gate(self, qubits: ir.SSAValue) -> bool: address = self.addr_frame.get(qubits) - if isinstance(address, AddressQubit): - is_first = self.first_gates.get(address.data, True) - self.first_gates[address.data] = False - return is_first - elif isinstance(address, AddressReg): - is_first = True - for addr_int in address.data: - is_first = is_first and self.first_gates.get(addr_int, True) - self.first_gates[addr_int] = False - - return is_first - - # NOTE: we should have a flat kernel with simple address analysis, so in case we don't - # get concrete addresses, we might as well error here since something's wrong - return False + if not isinstance(address, AddressReg): + # NOTE: we should have a flat kernel with simple address analysis, so in case we don't + # get concrete addresses, we might as well error here since something's wrong + return False + + is_first = True + for addr_int in address.data: + is_first = is_first and self.first_gates.get(addr_int, True) + self.first_gates[addr_int] = False + + return is_first def method_self(self, method: ir.Method) -> EmptyLattice: return self.lattice.bottom()