Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions src/bloqade/gemini/analysis/logical_validation/analysis.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,50 @@
from typing import Any
from dataclasses import dataclass
from dataclasses import field, dataclass

from kirin import ir
from kirin.interp import InterpreterError
from kirin.lattice import EmptyLattice
from kirin.analysis import Forward, ForwardFrame
from kirin.validation import ValidationPass

from bloqade import squin
from bloqade.analysis.address import Address, AddressReg, AddressQubit, AddressAnalysis


@dataclass
class _GeminiLogicalValidationAnalysis(Forward[EmptyLattice]):
keys = ["gemini.validate.logical"]

first_gate = True
lattice = EmptyLattice
addr_frame: ForwardFrame[Address]

first_gates: dict[int, bool] = field(init=False, default_factory=dict)

def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
if isinstance(node, squin.gate.stmts.Gate):
# NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
self.first_gate = False
raise InterpreterError(f"Missing implementation for gate {node}")

return tuple(self.lattice.bottom() for _ in range(len(node.results)))

def check_first_gate(self, qubits: ir.SSAValue) -> bool:
address = self.addr_frame.get(qubits)

if isinstance(address, AddressQubit):
is_first = self.first_gates.get(address.data, True)
self.first_gates[address.data] = False
return is_first
elif isinstance(address, AddressReg):
is_first = True
for addr_int in address.data:
is_first = is_first and self.first_gates.get(addr_int, True)
self.first_gates[addr_int] = False

return is_first

# NOTE: we should have a flat kernel with simple address analysis, so in case we don't
# get concrete addresses, we might as well error here since something's wrong
return False

def method_self(self, method: ir.Method) -> EmptyLattice:
return self.lattice.bottom()

Expand All @@ -34,7 +57,10 @@ def name(self) -> str:
return "Gemini Logical Validation"

def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
analysis = _GeminiLogicalValidationAnalysis(method.dialects)
addr_frame, _ = AddressAnalysis(method.dialects).run(method)
analysis = _GeminiLogicalValidationAnalysis(
method.dialects, addr_frame=addr_frame
)
frame, _ = analysis.run(method)

return frame, analysis.get_validation_errors()
34 changes: 32 additions & 2 deletions src/bloqade/gemini/analysis/logical_validation/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def non_clifford(
frame: ForwardFrame,
stmt: gate.stmts.SingleQubitGate | gate.stmts.RotationGate,
):
if interp.first_gate:
interp.first_gate = False
if interp.check_first_gate(stmt.qubits):
return ()

interp.add_validation_error(
Expand All @@ -92,3 +91,34 @@ def non_clifford(
),
)
return ()

@_interp.impl(gate.stmts.X)
@_interp.impl(gate.stmts.Y)
@_interp.impl(gate.stmts.SqrtX)
@_interp.impl(gate.stmts.SqrtY)
@_interp.impl(gate.stmts.Z)
@_interp.impl(gate.stmts.H)
@_interp.impl(gate.stmts.S)
def clifford(
self,
interp: _GeminiLogicalValidationAnalysis,
frame: ForwardFrame,
stmt: gate.stmts.SingleQubitGate,
):
# NOTE: ignore result, but make sure the first gate flag is set to False
interp.check_first_gate(stmt.qubits)

return ()

@_interp.impl(gate.stmts.CX)
@_interp.impl(gate.stmts.CY)
@_interp.impl(gate.stmts.CZ)
def controlled_gate(
self,
interp: _GeminiLogicalValidationAnalysis,
frame: ForwardFrame,
stmt: gate.stmts.ControlledGate,
):
interp.check_first_gate(stmt.controls)
interp.check_first_gate(stmt.targets)
return ()
19 changes: 18 additions & 1 deletion test/gemini/test_logical_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from bloqade import squin, gemini
from bloqade.types import Qubit
from bloqade.analysis.address import AddressAnalysis
from bloqade.gemini.analysis.logical_validation.analysis import (
GeminiLogicalValidation,
_GeminiLogicalValidationAnalysis,
Expand Down Expand Up @@ -35,7 +36,10 @@ def main():
if m2:
squin.y(q[2])

frame, _ = _GeminiLogicalValidationAnalysis(main.dialects).run_no_raise(main)
addr_frame, _ = AddressAnalysis(main.dialects).run(main)
frame, _ = _GeminiLogicalValidationAnalysis(
main.dialects, addr_frame=addr_frame
).run_no_raise(main)

main.print(analysis=frame.entries)

Expand Down Expand Up @@ -182,3 +186,16 @@ def main(n: int):
assert len(e.errors) == 4

assert did_error


def test_non_clifford_parallel_gates():
@gemini.logical.kernel
def main():
q = squin.qalloc(5)
squin.rx(0.123, q[0])
squin.broadcast.ry(0.333, q[1:])

squin.broadcast.x(q)
squin.broadcast.h(q[1:])

main.print()
Loading