Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions squin_op_playground.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from kirin import ir, types
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you wanna delete the playground?

Copy link
Contributor Author

@johnzl-777 johnzl-777 Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! Thank you for catching that. I should probably do a local ignore for any _playground.py files moving forward

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually use a main.py as my playground it's in the .gitignore already IIRC

from kirin.passes import Fold
from kirin.dialects import py, func

from bloqade import squin
from bloqade.squin.analysis import nsites


def as_int(value: int):
return py.constant.Constant(value=value)


squin_with_qasm_core = squin.groups.wired.add(py)

stmts: list[ir.Statement] = [
(h0 := squin.op.stmts.H()),
(h1 := squin.op.stmts.H()),
(hh := squin.op.stmts.Kron(lhs=h1.result, rhs=h0.result)),
(chh := squin.op.stmts.Control(hh.result, n_controls=1)),
(factor := as_int(1)),
(schh := squin.op.stmts.Scale(chh.result, factor=factor.result)),
(func.Return(schh.result)),
]

block = ir.Block(stmts)
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
func_wrapper = func.Function(
sym_name="main",
signature=func.Signature(inputs=(), output=squin.op.types.OpType),
body=ir.Region(blocks=block),
)

constructed_method = ir.Method(
mod=None,
py_func=None,
sym_name="main",
dialects=squin_with_qasm_core,
code=func_wrapper,
arg_names=[],
)

fold_pass = Fold(squin_with_qasm_core)
fold_pass(constructed_method)

""""
address_frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis(
constructed_method, no_raise=False
)
constructed_method.print(analysis=address_frame.entries)
"""

shape_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
constructed_method, no_raise=False
)


constructed_method.print(analysis=shape_frame.entries)
3 changes: 3 additions & 0 deletions src/bloqade/squin/analysis/nsites/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Need this for impl registration to work properly!
from . import impls as impls
from .analysis import NSitesAnalysis as NSitesAnalysis
52 changes: 52 additions & 0 deletions src/bloqade/squin/analysis/nsites/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# from typing import cast

from kirin import ir
from kirin.analysis import Forward
from kirin.analysis.forward import ForwardFrame

from bloqade.squin.op.types import OpType
from bloqade.squin.op.traits import Sites, HasNSitesTrait

from .lattice import NSites, NoSites, HasNSites


class NSitesAnalysis(Forward[NSites]):

keys = ["op.nsites"]
lattice = NSites

# Take a page from const prop in Kirin,
# I can get the data I want from the SizedTrait
# and go from there

## This gets called before the registry look up
def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
method = self.lookup_registry(frame, stmt)
if method is not None:
return method(self, frame, stmt)
elif stmt.has_trait(HasNSitesTrait):
has_n_sites_trait = stmt.get_trait(HasNSitesTrait)
sites = has_n_sites_trait.get_sites(stmt)
return (HasNSites(sites=sites),)
elif stmt.has_trait(Sites):
sites_trait = stmt.get_trait(Sites)
return (HasNSites(sites=sites_trait.data),)
else:
return (NoSites(),)

# For when no implementation is found for the statement
def eval_stmt_fallback(
self, frame: ForwardFrame[NSites], stmt: ir.Statement
) -> tuple[NSites, ...]: # some form of Shape will go back into the frame
return tuple(
(
self.lattice.top()
if result.type.is_subseteq(OpType)
else self.lattice.bottom()
)
for result in stmt.results
)

def run_method(self, method: ir.Method, args: tuple[NSites, ...]):
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
69 changes: 69 additions & 0 deletions src/bloqade/squin/analysis/nsites/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import cast

from kirin import ir, interp

from bloqade.squin import op

from .lattice import (
NoSites,
HasNSites,
)
from .analysis import NSitesAnalysis


@op.dialect.register(key="op.nsites")
class SquinOp(interp.MethodTable):

@interp.impl(op.stmts.Kron)
def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)
if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites):
new_n_sites = lhs.sites + rhs.sites
return (HasNSites(sites=new_n_sites),)
else:
return (NoSites(),)

@interp.impl(op.stmts.Mult)
def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)

if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites):
lhs_sites = lhs.sites
rhs_sites = rhs.sites
# I originally considered throwing an exception here
# but Xiu-zhe (Roger) Luo has pointed out it would be
# a much better UX to add a type element that
# could explicitly indicate the error. The downside
# is you'll have some added complexity in the type lattice.
if lhs_sites != rhs_sites:
return (NoSites(),)
else:
return (HasNSites(sites=lhs_sites + rhs_sites),)
else:
return (NoSites(),)

@interp.impl(op.stmts.Control)
def control(
self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Control
):
op_sites = frame.get(stmt.op)

if isinstance(op_sites, HasNSites):
n_sites = op_sites.sites
n_controls_attr = stmt.get_attr_or_prop("n_controls")
n_controls = cast(ir.PyAttr[int], n_controls_attr).data
return (HasNSites(sites=n_sites + n_controls),)
else:
return (NoSites(),)

@interp.impl(op.stmts.Rot)
def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot):
op_sites = frame.get(stmt.axis)
return (op_sites,)

@interp.impl(op.stmts.Scale)
def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
op_sites = frame.get(stmt.op)
return (op_sites,)
49 changes: 49 additions & 0 deletions src/bloqade/squin/analysis/nsites/lattice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import final
from dataclasses import dataclass

from kirin.lattice import (
SingletonMeta,
BoundedLattice,
SimpleJoinMixin,
SimpleMeetMixin,
)


@dataclass
class NSites(
SimpleJoinMixin["NSites"], SimpleMeetMixin["NSites"], BoundedLattice["NSites"]
):
@classmethod
def bottom(cls) -> "NSites":
return NoSites()

@classmethod
def top(cls) -> "NSites":
return AnySites()


@final
@dataclass
class NoSites(NSites, metaclass=SingletonMeta):

def is_subseteq(self, other: NSites) -> bool:
return True


@final
@dataclass
class AnySites(NSites, metaclass=SingletonMeta):

def is_subseteq(self, other: NSites) -> bool:
return isinstance(other, NSites)


@final
@dataclass
class HasNSites(NSites):
sites: int

def is_subseteq(self, other: NSites) -> bool:
if isinstance(other, HasNSites):
return self.sites == other.sites
return False
14 changes: 7 additions & 7 deletions src/bloqade/squin/op/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.decl import info, statement

from .types import OpType
from .traits import Sized, HasSize, Unitary, MaybeUnitary
from .traits import Sites, Unitary, MaybeUnitary, HasNSitesTrait
from .complex import Complex
from ._dialect import dialect

Expand Down Expand Up @@ -77,21 +77,21 @@ class Rot(CompositeOp):

@statement(dialect=dialect)
class Identity(CompositeOp):
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSize()})
size: int = info.attribute()
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasNSitesTrait()})
sites: int = info.attribute()
result: ir.ResultValue = info.result(OpType)


@statement
class ConstantOp(PrimitiveOp):
traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sized(1)})
traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sites(1)})
result: ir.ResultValue = info.result(OpType)


@statement
class ConstantUnitary(ConstantOp):
traits = frozenset(
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sized(1)}
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sites(1)}
)


Expand All @@ -105,7 +105,7 @@ class PhaseOp(PrimitiveOp):
$$
"""

traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)})
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)})
theta: ir.SSAValue = info.argument(types.Float)
result: ir.ResultValue = info.result(OpType)

Expand All @@ -120,7 +120,7 @@ class ShiftOp(PrimitiveOp):
$$
"""

traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)})
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)})
theta: ir.SSAValue = info.argument(types.Float)
result: ir.ResultValue = info.result(OpType)

Expand Down
16 changes: 8 additions & 8 deletions src/bloqade/squin/op/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@


@dataclass(frozen=True)
class Sized(ir.StmtTrait):
class Sites(ir.StmtTrait):
data: int


@dataclass(frozen=True)
class HasSize(ir.StmtTrait):
"""An operator with a `size` attribute."""
class HasNSitesTrait(ir.StmtTrait):
"""An operator with a `sites` attribute."""

def get_size(self, stmt: ir.Statement):
attr = stmt.get_attr_or_prop("size")
def get_sites(self, stmt: ir.Statement):
attr = stmt.get_attr_or_prop("sites")

Check warning on line 17 in src/bloqade/squin/op/traits.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/op/traits.py#L17

Added line #L17 was not covered by tests
if attr is None:
raise ValueError(f"Missing size attribute in {stmt}")
raise ValueError(f"Missing sites attribute in {stmt}")

Check warning on line 19 in src/bloqade/squin/op/traits.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/op/traits.py#L19

Added line #L19 was not covered by tests
return cast(ir.PyAttr[int], attr).data

def set_size(self, stmt: ir.Statement, value: int):
stmt.attributes["size"] = ir.PyAttr(value)
def set_sites(self, stmt: ir.Statement, value: int):
stmt.attributes["sites"] = ir.PyAttr(value)

Check warning on line 23 in src/bloqade/squin/op/traits.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/op/traits.py#L23

Added line #L23 was not covered by tests
return


Expand Down