diff --git a/src/bloqade/squin/analysis/nsites/__init__.py b/src/bloqade/squin/analysis/nsites/__init__.py new file mode 100644 index 00000000..da0a8e86 --- /dev/null +++ b/src/bloqade/squin/analysis/nsites/__init__.py @@ -0,0 +1,8 @@ +# Need this for impl registration to work properly! +from . import impls as impls +from .lattice import ( + NoSites as NoSites, + AnySites as AnySites, + NumberSites as NumberSites, +) +from .analysis import NSitesAnalysis as NSitesAnalysis diff --git a/src/bloqade/squin/analysis/nsites/analysis.py b/src/bloqade/squin/analysis/nsites/analysis.py new file mode 100644 index 00000000..2c4cd054 --- /dev/null +++ b/src/bloqade/squin/analysis/nsites/analysis.py @@ -0,0 +1,52 @@ +# from typing import cast + +from kirin import ir +from kirin.analysis import Forward +from kirin.analysis.forward import ForwardFrame + +from bloqade.squin.op.types import OpType +from bloqade.squin.op.traits import HasSites, FixedSites + +from .lattice import Sites, NoSites, NumberSites + + +class NSitesAnalysis(Forward[Sites]): + + keys = ["op.nsites"] + lattice = Sites + + # Take a page from const prop in Kirin, + # I can get the data I want from the SizedTrait + # and go from there + + ## This gets called before the registry look up + def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): + method = self.lookup_registry(frame, stmt) + if method is not None: + return method(self, frame, stmt) + elif stmt.has_trait(HasSites): + has_sites_trait = stmt.get_trait(HasSites) + sites = has_sites_trait.get_sites(stmt) + return (NumberSites(sites=sites),) + elif stmt.has_trait(FixedSites): + sites_trait = stmt.get_trait(FixedSites) + return (NumberSites(sites=sites_trait.data),) + else: + return (NoSites(),) + + # For when no implementation is found for the statement + def eval_stmt_fallback( + self, frame: ForwardFrame[Sites], stmt: ir.Statement + ) -> tuple[Sites, ...]: # some form of Shape will go back into the frame + return tuple( + ( + self.lattice.top() + if result.type.is_subseteq(OpType) + else self.lattice.bottom() + ) + for result in stmt.results + ) + + def run_method(self, method: ir.Method, args: tuple[Sites, ...]): + # NOTE: we do not support dynamic calls here, thus no need to propagate method object + return self.run_callable(method.code, (self.lattice.bottom(),) + args) diff --git a/src/bloqade/squin/analysis/nsites/impls.py b/src/bloqade/squin/analysis/nsites/impls.py new file mode 100644 index 00000000..3a4f94f1 --- /dev/null +++ b/src/bloqade/squin/analysis/nsites/impls.py @@ -0,0 +1,69 @@ +from typing import cast + +from kirin import ir, interp + +from bloqade.squin import op + +from .lattice import ( + NoSites, + NumberSites, +) +from .analysis import NSitesAnalysis + + +@op.dialect.register(key="op.nsites") +class SquinOp(interp.MethodTable): + + @interp.impl(op.stmts.Kron) + def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron): + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites): + new_n_sites = lhs.sites + rhs.sites + return (NumberSites(sites=new_n_sites),) + else: + return (NoSites(),) + + @interp.impl(op.stmts.Mult) + def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult): + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + + if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites): + lhs_sites = lhs.sites + rhs_sites = rhs.sites + # I originally considered throwing an exception here + # but Xiu-zhe (Roger) Luo has pointed out it would be + # a much better UX to add a type element that + # could explicitly indicate the error. The downside + # is you'll have some added complexity in the type lattice. + if lhs_sites != rhs_sites: + return (NoSites(),) + else: + return (NumberSites(sites=lhs_sites + rhs_sites),) + else: + return (NoSites(),) + + @interp.impl(op.stmts.Control) + def control( + self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Control + ): + op_sites = frame.get(stmt.op) + + if isinstance(op_sites, NumberSites): + n_sites = op_sites.sites + n_controls_attr = stmt.get_attr_or_prop("n_controls") + n_controls = cast(ir.PyAttr[int], n_controls_attr).data + return (NumberSites(sites=n_sites + n_controls),) + else: + return (NoSites(),) + + @interp.impl(op.stmts.Rot) + def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot): + op_sites = frame.get(stmt.axis) + return (op_sites,) + + @interp.impl(op.stmts.Scale) + def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale): + op_sites = frame.get(stmt.op) + return (op_sites,) diff --git a/src/bloqade/squin/analysis/nsites/lattice.py b/src/bloqade/squin/analysis/nsites/lattice.py new file mode 100644 index 00000000..cf11ef72 --- /dev/null +++ b/src/bloqade/squin/analysis/nsites/lattice.py @@ -0,0 +1,49 @@ +from typing import final +from dataclasses import dataclass + +from kirin.lattice import ( + SingletonMeta, + BoundedLattice, + SimpleJoinMixin, + SimpleMeetMixin, +) + + +@dataclass +class Sites( + SimpleJoinMixin["Sites"], SimpleMeetMixin["Sites"], BoundedLattice["Sites"] +): + @classmethod + def bottom(cls) -> "Sites": + return NoSites() + + @classmethod + def top(cls) -> "Sites": + return AnySites() + + +@final +@dataclass +class NoSites(Sites, metaclass=SingletonMeta): + + def is_subseteq(self, other: Sites) -> bool: + return True + + +@final +@dataclass +class AnySites(Sites, metaclass=SingletonMeta): + + def is_subseteq(self, other: Sites) -> bool: + return isinstance(other, Sites) + + +@final +@dataclass +class NumberSites(Sites): + sites: int + + def is_subseteq(self, other: Sites) -> bool: + if isinstance(other, NumberSites): + return self.sites == other.sites + return False diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index 09fb6052..dcffd93a 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -2,7 +2,7 @@ from kirin.decl import info, statement from .types import OpType -from .traits import Sized, HasSize, Unitary, MaybeUnitary +from .traits import Unitary, HasSites, FixedSites, MaybeUnitary from .complex import Complex from ._dialect import dialect @@ -77,21 +77,23 @@ class Rot(CompositeOp): @statement(dialect=dialect) class Identity(CompositeOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSize()}) - size: int = info.attribute() + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSites()}) + sites: int = info.attribute() result: ir.ResultValue = info.result(OpType) @statement class ConstantOp(PrimitiveOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sized(1)}) + traits = frozenset( + {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), FixedSites(1)} + ) result: ir.ResultValue = info.result(OpType) @statement class ConstantUnitary(ConstantOp): traits = frozenset( - {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sized(1)} + {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), FixedSites(1)} ) @@ -105,7 +107,7 @@ class PhaseOp(PrimitiveOp): $$ """ - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), FixedSites(1)}) theta: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) @@ -120,7 +122,7 @@ class ShiftOp(PrimitiveOp): $$ """ - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), FixedSites(1)}) theta: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) diff --git a/src/bloqade/squin/op/traits.py b/src/bloqade/squin/op/traits.py index 28ea65f0..506fdbf2 100644 --- a/src/bloqade/squin/op/traits.py +++ b/src/bloqade/squin/op/traits.py @@ -5,22 +5,22 @@ @dataclass(frozen=True) -class Sized(ir.StmtTrait): +class FixedSites(ir.StmtTrait): data: int @dataclass(frozen=True) -class HasSize(ir.StmtTrait): - """An operator with a `size` attribute.""" +class HasSites(ir.StmtTrait): + """An operator with a `sites` attribute.""" - def get_size(self, stmt: ir.Statement): - attr = stmt.get_attr_or_prop("size") + def get_sites(self, stmt: ir.Statement): + attr = stmt.get_attr_or_prop("sites") if attr is None: - raise ValueError(f"Missing size attribute in {stmt}") + raise ValueError(f"Missing sites attribute in {stmt}") return cast(ir.PyAttr[int], attr).data - def set_size(self, stmt: ir.Statement, value: int): - stmt.attributes["size"] = ir.PyAttr(value) + def set_sites(self, stmt: ir.Statement, value: int): + stmt.attributes["sites"] = ir.PyAttr(value) return diff --git a/test/squin/analysis/test_nsites_analysis.py b/test/squin/analysis/test_nsites_analysis.py new file mode 100644 index 00000000..19caa3e0 --- /dev/null +++ b/test/squin/analysis/test_nsites_analysis.py @@ -0,0 +1,256 @@ +from kirin import ir, types +from kirin.passes import Fold +from kirin.dialects import py, func + +from bloqade import squin +from bloqade.squin.analysis import nsites + + +def as_int(value: int): + return py.constant.Constant(value=value) + + +def as_float(value: float): + return py.constant.Constant(value=value) + + +def gen_func_from_stmts(stmts): + + squin_with_py = squin.groups.wired.add(py) + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=squin.op.types.OpType), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=squin_with_py, + code=func_wrapper, + arg_names=[], + ) + + fold_pass = Fold(squin_with_py) + fold_pass(constructed_method) + + return constructed_method + + +def test_primitive_ops(): + # test a couple standard operators derived from PrimitiveOp + + stmts = [ + (n_qubits := as_int(1)), + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), + (idx0 := as_int(0)), + (q := py.GetItem(qreg.result, idx0.result)), + # get wire + (w := squin.wire.Unwrap(q.result)), + # put wire through gates + (h := squin.op.stmts.H()), + (t := squin.op.stmts.T()), + (x := squin.op.stmts.X()), + (v0 := squin.wire.Apply(h.result, w.result)), + (v1 := squin.wire.Apply(t.result, v0.results[0])), + (v2 := squin.wire.Apply(x.result, v1.results[0])), + (func.Return(v2.results[0])), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.NumberSites): + has_n_sites.append(nsites_type) + assert nsites_type.sites == 1 + + assert len(has_n_sites) == 3 + + +# Kron, Mult, Control, Rot, and Scale all have methods defined for handling them in impls, +# The following should ensure the code paths are properly exercised + + +def test_control(): + # Control doesn't have an impl but it is handled in the eval_stmt of the interpreter + # because it has a HasNSitesTrait future statements might have + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (controlled_h := squin.op.stmts.Control(op=h0.result, n_controls=1)), + (func.Return(controlled_h.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.NumberSites): + has_n_sites.append(nsites_type) + + assert len(has_n_sites) == 2 + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 2 + + +def test_kron(): + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (h1 := squin.op.stmts.H()), + (hh := squin.op.stmts.Kron(h0.result, h1.result)), + (func.Return(hh.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.NumberSites): + has_n_sites.append(nsites_type) + + assert len(has_n_sites) == 3 + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 1 + assert has_n_sites[2].sites == 2 + + +def test_mult_square_same_sites(): + # Ensure that two operators of the same size produce + # a valid operator as their result + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (h1 := squin.op.stmts.H()), + (h2 := squin.op.stmts.Mult(h0.result, h1.result)), + (func.Return(h2.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.NumberSites): + has_n_sites.append(nsites_type) + + # should be three HasNSites types + assert len(has_n_sites) == 3 + # the first 2 HasNSites will have 1 site but + # the Kron-produced operator should have 2 sites + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 1 + assert has_n_sites[2].sites == 2 + + +def test_mult_square_different_sites(): + # Ensure that two operators of different sizes produce + # NoSites as a type. Note that a better solution would be + # to implement a special error type in the type lattice + # but this would introduce some complexity later on + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (h1 := squin.op.stmts.H()), + # Kron to make nsites = 2 operator + (hh := squin.op.stmts.Kron(h0.result, h1.result)), + # apply Mult on HasNSites(2) and HasNSites(1) + (invalid_op := squin.op.stmts.Mult(hh.result, h1.result)), + (func.Return(invalid_op.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + nsites_types = list(nsites_frame.entries.values()) + + has_n_sites = [] + no_sites = [] + for nsite_type in nsites_types: + if isinstance(nsite_type, nsites.NumberSites): + has_n_sites.append(nsite_type) + elif isinstance(nsite_type, nsites.NoSites): + no_sites.append(nsite_type) + + assert len(has_n_sites) == 3 + # HasNSites(1) for Hadamards, 2 for Kron result + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 1 + assert has_n_sites[2].sites == 2 + # One from function itself, another from invalid mult + assert len(no_sites) == 2 + + +def test_rot(): + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (angle := as_float(0.2)), + (rot_h := squin.op.stmts.Rot(axis=h0.result, angle=angle.result)), + (func.Return(rot_h.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.NumberSites): + has_n_sites.append(nsites_type) + + assert len(has_n_sites) == 2 + # Rot should just propagate whatever Sites type is there + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 1 + + +def test_scale(): + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (factor := as_float(0.2)), + (rot_h := squin.op.stmts.Scale(op=h0.result, factor=factor.result)), + (func.Return(rot_h.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.NumberSites): + has_n_sites.append(nsites_type) + + assert len(has_n_sites) == 2 + # Rot should just propagate whatever Sites type is there + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 1