Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/bloqade/squin/analysis/nsites/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Need this for impl registration to work properly!
from . import impls as impls
from .lattice import (
NoSites as NoSites,
AnySites as AnySites,
NumberSites as NumberSites,
)
from .analysis import NSitesAnalysis as NSitesAnalysis
52 changes: 52 additions & 0 deletions src/bloqade/squin/analysis/nsites/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# from typing import cast

from kirin import ir
from kirin.analysis import Forward
from kirin.analysis.forward import ForwardFrame

from bloqade.squin.op.types import OpType
from bloqade.squin.op.traits import HasSites, FixedSites

from .lattice import Sites, NoSites, NumberSites


class NSitesAnalysis(Forward[Sites]):

keys = ["op.nsites"]
lattice = Sites

# Take a page from const prop in Kirin,
# I can get the data I want from the SizedTrait
# and go from there

## This gets called before the registry look up
def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
method = self.lookup_registry(frame, stmt)
if method is not None:
return method(self, frame, stmt)
elif stmt.has_trait(HasSites):
has_sites_trait = stmt.get_trait(HasSites)
sites = has_sites_trait.get_sites(stmt)
return (NumberSites(sites=sites),)

Check warning on line 30 in src/bloqade/squin/analysis/nsites/analysis.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/analysis/nsites/analysis.py#L28-L30

Added lines #L28 - L30 were not covered by tests
elif stmt.has_trait(FixedSites):
sites_trait = stmt.get_trait(FixedSites)
return (NumberSites(sites=sites_trait.data),)
else:
return (NoSites(),)

# For when no implementation is found for the statement
def eval_stmt_fallback(
self, frame: ForwardFrame[Sites], stmt: ir.Statement
) -> tuple[Sites, ...]: # some form of Shape will go back into the frame
return tuple(

Check warning on line 41 in src/bloqade/squin/analysis/nsites/analysis.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/analysis/nsites/analysis.py#L41

Added line #L41 was not covered by tests
(
self.lattice.top()
if result.type.is_subseteq(OpType)
else self.lattice.bottom()
)
for result in stmt.results
)

def run_method(self, method: ir.Method, args: tuple[Sites, ...]):
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
69 changes: 69 additions & 0 deletions src/bloqade/squin/analysis/nsites/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import cast

from kirin import ir, interp

from bloqade.squin import op

from .lattice import (
NoSites,
NumberSites,
)
from .analysis import NSitesAnalysis


@op.dialect.register(key="op.nsites")
class SquinOp(interp.MethodTable):

@interp.impl(op.stmts.Kron)
def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)
if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites):
new_n_sites = lhs.sites + rhs.sites
return (NumberSites(sites=new_n_sites),)
else:
return (NoSites(),)

Check warning on line 25 in src/bloqade/squin/analysis/nsites/impls.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/analysis/nsites/impls.py#L25

Added line #L25 was not covered by tests

@interp.impl(op.stmts.Mult)
def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)

if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites):
lhs_sites = lhs.sites
rhs_sites = rhs.sites
# I originally considered throwing an exception here
# but Xiu-zhe (Roger) Luo has pointed out it would be
# a much better UX to add a type element that
# could explicitly indicate the error. The downside
# is you'll have some added complexity in the type lattice.
if lhs_sites != rhs_sites:
return (NoSites(),)
else:
return (NumberSites(sites=lhs_sites + rhs_sites),)
else:
return (NoSites(),)

Check warning on line 45 in src/bloqade/squin/analysis/nsites/impls.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/analysis/nsites/impls.py#L45

Added line #L45 was not covered by tests

@interp.impl(op.stmts.Control)
def control(
self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Control
):
op_sites = frame.get(stmt.op)

if isinstance(op_sites, NumberSites):
n_sites = op_sites.sites
n_controls_attr = stmt.get_attr_or_prop("n_controls")
n_controls = cast(ir.PyAttr[int], n_controls_attr).data
return (NumberSites(sites=n_sites + n_controls),)
else:
return (NoSites(),)

Check warning on line 59 in src/bloqade/squin/analysis/nsites/impls.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/analysis/nsites/impls.py#L59

Added line #L59 was not covered by tests

@interp.impl(op.stmts.Rot)
def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot):
op_sites = frame.get(stmt.axis)
return (op_sites,)

@interp.impl(op.stmts.Scale)
def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
op_sites = frame.get(stmt.op)
return (op_sites,)
49 changes: 49 additions & 0 deletions src/bloqade/squin/analysis/nsites/lattice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import final
from dataclasses import dataclass

from kirin.lattice import (
SingletonMeta,
BoundedLattice,
SimpleJoinMixin,
SimpleMeetMixin,
)


@dataclass
class Sites(
SimpleJoinMixin["Sites"], SimpleMeetMixin["Sites"], BoundedLattice["Sites"]
):
@classmethod
def bottom(cls) -> "Sites":
return NoSites()

@classmethod
def top(cls) -> "Sites":
return AnySites()

Check warning on line 22 in src/bloqade/squin/analysis/nsites/lattice.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/analysis/nsites/lattice.py#L22

Added line #L22 was not covered by tests


@final
@dataclass
class NoSites(Sites, metaclass=SingletonMeta):

def is_subseteq(self, other: Sites) -> bool:
return True


@final
@dataclass
class AnySites(Sites, metaclass=SingletonMeta):

def is_subseteq(self, other: Sites) -> bool:
return isinstance(other, Sites)

Check warning on line 38 in src/bloqade/squin/analysis/nsites/lattice.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/analysis/nsites/lattice.py#L38

Added line #L38 was not covered by tests


@final
@dataclass
class NumberSites(Sites):
sites: int

def is_subseteq(self, other: Sites) -> bool:
if isinstance(other, NumberSites):
return self.sites == other.sites
return False

Check warning on line 49 in src/bloqade/squin/analysis/nsites/lattice.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/analysis/nsites/lattice.py#L47-L49

Added lines #L47 - L49 were not covered by tests
16 changes: 9 additions & 7 deletions src/bloqade/squin/op/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.decl import info, statement

from .types import OpType
from .traits import Sized, HasSize, Unitary, MaybeUnitary
from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
from .complex import Complex
from ._dialect import dialect

Expand Down Expand Up @@ -77,21 +77,23 @@ class Rot(CompositeOp):

@statement(dialect=dialect)
class Identity(CompositeOp):
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSize()})
size: int = info.attribute()
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSites()})
sites: int = info.attribute()
result: ir.ResultValue = info.result(OpType)


@statement
class ConstantOp(PrimitiveOp):
traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sized(1)})
traits = frozenset(
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), FixedSites(1)}
)
result: ir.ResultValue = info.result(OpType)


@statement
class ConstantUnitary(ConstantOp):
traits = frozenset(
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sized(1)}
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), FixedSites(1)}
)


Expand All @@ -105,7 +107,7 @@ class PhaseOp(PrimitiveOp):
$$
"""

traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)})
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), FixedSites(1)})
theta: ir.SSAValue = info.argument(types.Float)
result: ir.ResultValue = info.result(OpType)

Expand All @@ -120,7 +122,7 @@ class ShiftOp(PrimitiveOp):
$$
"""

traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)})
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), FixedSites(1)})
theta: ir.SSAValue = info.argument(types.Float)
result: ir.ResultValue = info.result(OpType)

Expand Down
16 changes: 8 additions & 8 deletions src/bloqade/squin/op/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@


@dataclass(frozen=True)
class Sized(ir.StmtTrait):
class FixedSites(ir.StmtTrait):
data: int


@dataclass(frozen=True)
class HasSize(ir.StmtTrait):
"""An operator with a `size` attribute."""
class HasSites(ir.StmtTrait):
"""An operator with a `sites` attribute."""

def get_size(self, stmt: ir.Statement):
attr = stmt.get_attr_or_prop("size")
def get_sites(self, stmt: ir.Statement):
attr = stmt.get_attr_or_prop("sites")

Check warning on line 17 in src/bloqade/squin/op/traits.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/op/traits.py#L17

Added line #L17 was not covered by tests
if attr is None:
raise ValueError(f"Missing size attribute in {stmt}")
raise ValueError(f"Missing sites attribute in {stmt}")

Check warning on line 19 in src/bloqade/squin/op/traits.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/op/traits.py#L19

Added line #L19 was not covered by tests
return cast(ir.PyAttr[int], attr).data

def set_size(self, stmt: ir.Statement, value: int):
stmt.attributes["size"] = ir.PyAttr(value)
def set_sites(self, stmt: ir.Statement, value: int):
stmt.attributes["sites"] = ir.PyAttr(value)

Check warning on line 23 in src/bloqade/squin/op/traits.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/op/traits.py#L23

Added line #L23 was not covered by tests
return


Expand Down
Loading