Skip to content

Commit 64dd86a

Browse files
committed
Implement no-cloning validation
1 parent 1642980 commit 64dd86a

File tree

6 files changed

+331
-0
lines changed

6 files changed

+331
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from . import impls as impls
2+
from .analysis import NoCloningValidation as NoCloningValidation
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from dataclasses import field
2+
3+
from kirin import ir
4+
from kirin.analysis import Forward, TypeInference
5+
from kirin.dialects import func
6+
from kirin.analysis.forward import ForwardFrame
7+
8+
from bloqade.analysis.address import (
9+
Address,
10+
AddressReg,
11+
AddressQubit,
12+
AddressAnalysis,
13+
)
14+
from bloqade.analysis.address.lattice import QubitLike
15+
16+
from .lattice import QubitValidation
17+
18+
19+
class NoCloningValidation(Forward[QubitValidation]):
20+
"""
21+
Validates the no-cloning theorem by tracking qubit addresses.
22+
23+
Built on top of AddressAnalysis to get qubit address information.
24+
"""
25+
26+
keys = ["validate.nocloning"]
27+
lattice = QubitValidation
28+
_address_frame: ForwardFrame[Address] = field(init=False)
29+
_type_frame: ForwardFrame = field(init=False)
30+
method: ir.Method
31+
violations: int = field(default=0, init=False)
32+
33+
def __init__(self, mtd: ir.Method):
34+
"""
35+
Input:
36+
- an ir.Method / kernel function
37+
infer dialects from it and remember method.
38+
"""
39+
self.method = mtd
40+
super().__init__(mtd.dialects)
41+
42+
def initialize(self):
43+
super().initialize()
44+
45+
address_analysis = AddressAnalysis(self.dialects)
46+
address_analysis.initialize()
47+
self._address_frame, _ = address_analysis.run_analysis(self.method)
48+
49+
type_inference = TypeInference(self.dialects)
50+
type_inference.initialize()
51+
self._type_frame, _ = type_inference.run_analysis(self.method)
52+
53+
return self
54+
55+
def method_self(self, method: ir.Method) -> QubitValidation:
56+
return self.lattice.bottom()
57+
58+
def get_qubit_addresses(self, addr: Address) -> frozenset[int]:
59+
"""Extract concrete qubit addresses from an Address lattice element."""
60+
match addr:
61+
case AddressQubit(data=qubit_addr):
62+
return frozenset([qubit_addr])
63+
case AddressReg(data=addrs):
64+
return frozenset(addrs)
65+
case _:
66+
return frozenset()
67+
68+
def get_stmt_info(self, stmt: ir.Statement) -> str:
69+
"""String Report about the statement for violation messages."""
70+
if isinstance(stmt, func.Invoke) and hasattr(stmt, "callee"):
71+
gate_name = stmt.callee.sym_name.upper()
72+
return f"{gate_name} Gate"
73+
74+
return f"{stmt.__class__.__name__}@{stmt}"
75+
76+
def eval_stmt_fallback(
77+
self, frame: ForwardFrame[QubitValidation], stmt: ir.Statement
78+
) -> tuple[QubitValidation, ...]:
79+
"""
80+
Default statement evaluation: check for qubit usage violations.
81+
"""
82+
83+
if not isinstance(stmt, func.Invoke):
84+
return tuple(QubitValidation.bottom() for _ in stmt.results)
85+
86+
address_frame = self._address_frame
87+
if address_frame is None:
88+
return tuple(QubitValidation.top() for _ in stmt.results)
89+
90+
has_qubit_args = any(
91+
isinstance(address_frame.get(arg), QubitLike) for arg in stmt.args
92+
)
93+
94+
if not has_qubit_args:
95+
return tuple(QubitValidation.bottom() for _ in stmt.results)
96+
97+
used_addrs: list[int] = []
98+
for arg in stmt.args:
99+
addr = address_frame.get(arg)
100+
qubit_addrs = self.get_qubit_addresses(addr)
101+
used_addrs.extend(qubit_addrs)
102+
103+
seen: set[int] = set()
104+
violations: list[str] = []
105+
stmt_info = self.get_stmt_info(stmt)
106+
107+
for qubit_addr in used_addrs:
108+
if qubit_addr in seen:
109+
violations.append(f"Qubit[{qubit_addr}] at {stmt_info}")
110+
seen.add(qubit_addr)
111+
112+
if not violations:
113+
return tuple(QubitValidation(violations=frozenset()) for _ in stmt.results)
114+
115+
usage = QubitValidation(violations=frozenset(violations))
116+
return tuple(usage for _ in stmt.results) if stmt.results else (usage,)
117+
118+
def run_method(
119+
self, method: ir.Method, args: tuple[QubitValidation, ...]
120+
) -> tuple[ForwardFrame[QubitValidation], QubitValidation]:
121+
self_mt = self.method_self(method)
122+
return self.run_callable(method.code, (self_mt,) + args)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from kirin import interp
2+
from kirin.analysis import ForwardFrame
3+
from kirin.dialects import scf
4+
5+
from .lattice import QubitValidation
6+
from .analysis import NoCloningValidation
7+
8+
9+
@scf.dialect.register(key="validate.nocloning")
10+
class Scf(interp.MethodTable):
11+
@interp.impl(scf.IfElse)
12+
def if_else(
13+
self,
14+
interp_: NoCloningValidation,
15+
frame: ForwardFrame[QubitValidation],
16+
stmt: scf.IfElse,
17+
):
18+
cond_validation = frame.get(stmt.cond)
19+
20+
then_results = interp_.run_callable_region(
21+
frame, stmt, stmt.then_body, (cond_validation,)
22+
)
23+
24+
if stmt.else_body:
25+
else_results = interp_.run_callable_region(
26+
frame, stmt, stmt.else_body, (cond_validation,)
27+
)
28+
29+
merged = tuple(then_results.join(else_results) for _ in stmt.results)
30+
else:
31+
merged = tuple(then_results for _ in stmt.results)
32+
33+
return merged if merged else (QubitValidation.bottom(),)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from typing import FrozenSet, final
2+
from dataclasses import field, dataclass
3+
4+
from kirin.lattice import (
5+
SingletonMeta,
6+
BoundedLattice,
7+
SimpleJoinMixin,
8+
SimpleMeetMixin,
9+
)
10+
11+
12+
@dataclass
13+
class QubitValidation(
14+
SimpleJoinMixin["QubitValidation"],
15+
SimpleMeetMixin["QubitValidation"],
16+
BoundedLattice["QubitValidation"],
17+
):
18+
"""Tracks cloning violations detected during analysis."""
19+
20+
violations: FrozenSet[str] = field(default_factory=frozenset)
21+
22+
@classmethod
23+
def bottom(cls) -> "QubitValidation":
24+
"""No violations detected"""
25+
return Bottom()
26+
27+
@classmethod
28+
def top(cls) -> "QubitValidation":
29+
"""Unknown state - assume potential violations"""
30+
return Top()
31+
32+
def is_subseteq(self, other: "QubitValidation") -> bool:
33+
"""Check if this state is a subset of another.
34+
35+
Lattice ordering:
36+
Bottom ⊑ {{'Qubit[1] at CX Gate'}} ⊑ {{'Qubit[0] at CX Gate'},{'Qubit[1] at CX Gate'}} ⊑ Top
37+
"""
38+
if isinstance(other, Top):
39+
return True
40+
if isinstance(self, Bottom):
41+
return True
42+
if isinstance(other, Bottom):
43+
return False
44+
45+
return self.violations.issubset(other.violations)
46+
47+
def __repr__(self) -> str:
48+
"""Custom repr to show violations clearly."""
49+
if not self.violations:
50+
return "QubitValidation()"
51+
return f"QubitValidation(violations={self.violations})"
52+
53+
54+
@final
55+
class Bottom(QubitValidation, metaclass=SingletonMeta):
56+
"""Bottom element representing no violations."""
57+
58+
def is_subseteq(self, other: QubitValidation) -> bool:
59+
"""Bottom is subset of everything."""
60+
return True
61+
62+
def __repr__(self) -> str:
63+
"""Cleaner printing."""
64+
return "⊥ (Bottom)"
65+
66+
67+
@final
68+
class Top(QubitValidation, metaclass=SingletonMeta):
69+
"""Top element representing unknown state with potential violations."""
70+
71+
def is_subseteq(self, other: QubitValidation) -> bool:
72+
"""Top is only subset of Top."""
73+
return isinstance(other, Top)
74+
75+
def __repr__(self) -> str:
76+
"""Cleaner printing."""
77+
return "⊤ (Top)"
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from typing import Any
2+
3+
import pytest
4+
from util import collect_validation_errors
5+
from kirin import ir
6+
from kirin.dialects.ilist.runtime import IList
7+
8+
from bloqade import squin
9+
from bloqade.types import Qubit
10+
from bloqade.analysis.validation.nocloning.lattice import QubitValidation
11+
from bloqade.analysis.validation.nocloning.analysis import NoCloningValidation
12+
13+
14+
@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz])
15+
def test_control_gate_fail(control_gate: ir.Method[[Qubit, Qubit], Any]):
16+
@squin.kernel
17+
def bad_control():
18+
q = squin.qalloc(1)
19+
control_gate(q[0], q[0])
20+
21+
validation = NoCloningValidation(bad_control)
22+
validation.initialize()
23+
frame, _ = validation.run_analysis(bad_control)
24+
print()
25+
bad_control.print(analysis=frame.entries)
26+
validation_errors = collect_validation_errors(frame, QubitValidation)
27+
assert len(validation_errors) == 1
28+
29+
30+
@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz])
31+
def test_control_gate_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]):
32+
@squin.kernel
33+
def bad_control(cond: bool):
34+
q = squin.qalloc(10)
35+
if cond:
36+
control_gate(q[0], q[0])
37+
else:
38+
control_gate(q[0], q[1])
39+
squin.cx(q[1], q[1])
40+
41+
validation = NoCloningValidation(bad_control)
42+
validation.initialize()
43+
frame, _ = validation.run_analysis(bad_control)
44+
print()
45+
bad_control.print(analysis=frame.entries)
46+
validation_errors = collect_validation_errors(frame, QubitValidation)
47+
# print("Violations:", validation_errors)
48+
assert len(validation_errors) == 2
49+
50+
51+
@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz])
52+
def test_control_gate_parallel_fail(control_gate: ir.Method[[Qubit, Qubit], Any]):
53+
@squin.kernel
54+
def bad_control():
55+
q = squin.qalloc(2)
56+
control_gate(q[0], q[1])
57+
58+
validation = NoCloningValidation(bad_control)
59+
validation.initialize()
60+
frame, _ = validation.run_analysis(bad_control)
61+
print()
62+
bad_control.print(analysis=frame.entries)
63+
validation_errors = collect_validation_errors(frame, QubitValidation)
64+
assert len(validation_errors) == 0
65+
66+
67+
def test_control_gate_parallel_pass():
68+
@squin.kernel
69+
def good_kernel():
70+
q = squin.qalloc(2)
71+
squin.cx(q[0], q[1])
72+
squin.cy(q[1], q[1])
73+
74+
validation = NoCloningValidation(good_kernel)
75+
validation.initialize()
76+
frame, _ = validation.run_analysis(good_kernel)
77+
print()
78+
good_kernel.print(analysis=frame.entries)
79+
validation_errors = collect_validation_errors(frame, QubitValidation)
80+
assert len(validation_errors) == 1

test/analysis/validation/util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import TypeVar
2+
3+
from kirin.analysis import ForwardFrame
4+
5+
from bloqade.analysis.validation.nocloning.lattice import QubitValidation
6+
7+
T = TypeVar("T", bound=QubitValidation)
8+
9+
10+
def collect_validation_errors(
11+
frame: ForwardFrame[QubitValidation], typ: type[T]
12+
) -> list[T]:
13+
return [
14+
validation_errors
15+
for validation_errors in frame.entries.values()
16+
if isinstance(validation_errors, typ) and len(validation_errors.violations) > 0
17+
]

0 commit comments

Comments
 (0)