Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions squin_op_playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from kirin.dialects import py, func

from bloqade import squin
from bloqade.squin.analysis import shape
from bloqade.squin.analysis import nsites


def as_int(value: int):
Expand Down Expand Up @@ -51,7 +51,7 @@ def as_int(value: int):
constructed_method.print(analysis=address_frame.entries)
"""

shape_frame, _ = shape.ShapeAnalysis(constructed_method.dialects).run_analysis(
shape_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
constructed_method, no_raise=False
)

Expand Down
1 change: 0 additions & 1 deletion src/bloqade/analysis/address/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class AddressAnalysis(Forward[Address]):
def initialize(self):
super().initialize()
self.next_address: int = 0
print(self.registry.statements)
return self

@property
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Need this for impl registration to work properly!
from . import impls as impls
from .analysis import ShapeAnalysis as ShapeAnalysis
from .analysis import NSitesAnalysis as NSitesAnalysis
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,15 @@
from kirin.analysis.forward import ForwardFrame

from bloqade.squin.op.types import OpType
from bloqade.squin.op.traits import Sized, HasSize
from bloqade.squin.op.traits import Sites, HasNSitesTrait

from .lattice import Shape, NoShape, OpShape
from .lattice import NSites, NoSites, HasNSites


class ShapeAnalysis(Forward[Shape]):
class NSitesAnalysis(Forward[NSites]):

keys = ["op.shape"]
lattice = Shape

def initialize(self):
super().initialize()
return self
keys = ["op.nsites"]
lattice = NSites

# Take a page from const prop in Kirin,
# I can get the data I want from the SizedTrait
Expand All @@ -28,20 +24,20 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
method = self.lookup_registry(frame, stmt)
if method is not None:
return method(self, frame, stmt)
elif stmt.has_trait(HasSize):
has_size_inst = stmt.get_trait(HasSize)
size = has_size_inst.get_size(stmt)
return (OpShape(size=size),)
elif stmt.has_trait(Sized):
size = stmt.get_trait(Sized)
return (OpShape(size=size.data),)
elif stmt.has_trait(HasNSitesTrait):
has_n_sites_trait = stmt.get_trait(HasNSitesTrait)
sites = has_n_sites_trait.get_sites(stmt)
return (HasNSites(sites=sites),)
elif stmt.has_trait(Sites):
sites_trait = stmt.get_trait(Sites)
return (HasNSites(sites=sites_trait.data),)
else:
return (NoShape(),)
return (NoSites(),)

# For when no implementation is found for the statement
def eval_stmt_fallback(
self, frame: ForwardFrame[Shape], stmt: ir.Statement
) -> tuple[Shape, ...]: # some form of Shape will go back into the frame
self, frame: ForwardFrame[NSites], stmt: ir.Statement
) -> tuple[NSites, ...]: # some form of Shape will go back into the frame
return tuple(
(
self.lattice.top()
Expand All @@ -51,6 +47,6 @@ def eval_stmt_fallback(
for result in stmt.results
)

def run_method(self, method: ir.Method, args: tuple[Shape, ...]):
def run_method(self, method: ir.Method, args: tuple[NSites, ...]):
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
69 changes: 69 additions & 0 deletions src/bloqade/squin/analysis/nsites/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import cast

from kirin import ir, interp

from bloqade.squin import op

from .lattice import (
NoSites,
HasNSites,
)
from .analysis import NSitesAnalysis


@op.dialect.register(key="op.nsites")
class SquinOp(interp.MethodTable):

@interp.impl(op.stmts.Kron)
def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)
if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites):
new_n_sites = lhs.sites + rhs.sites
return (HasNSites(sites=new_n_sites),)
else:
return (NoSites(),)

@interp.impl(op.stmts.Mult)
def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)

if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites):
lhs_sites = lhs.sites
rhs_sites = rhs.sites
# I originally considered throwing an exception here
# but Xiu-zhe (Roger) Luo has pointed out it would be
# a much better UX to add a type element that
# could explicitly indicate the error. The downside
# is you'll have some added complexity in the type lattice.
if lhs_sites != rhs_sites:
return (NoSites(),)
else:
return (HasNSites(sites=lhs_sites + rhs_sites),)
else:
return (NoSites(),)

@interp.impl(op.stmts.Control)
def control(
self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Control
):
op_sites = frame.get(stmt.op)

if isinstance(op_sites, HasNSites):
n_sites = op_sites.sites
n_controls_attr = stmt.get_attr_or_prop("n_controls")
n_controls = cast(ir.PyAttr[int], n_controls_attr).data
return (HasNSites(sites=n_sites + n_controls),)
else:
return (NoSites(),)

@interp.impl(op.stmts.Rot)
def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot):
op_sites = frame.get(stmt.axis)
return (op_sites,)

@interp.impl(op.stmts.Scale)
def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
op_sites = frame.get(stmt.op)
return (op_sites,)
49 changes: 49 additions & 0 deletions src/bloqade/squin/analysis/nsites/lattice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import final
from dataclasses import dataclass

from kirin.lattice import (
SingletonMeta,
BoundedLattice,
SimpleJoinMixin,
SimpleMeetMixin,
)


@dataclass
class NSites(
SimpleJoinMixin["NSites"], SimpleMeetMixin["NSites"], BoundedLattice["NSites"]
):
@classmethod
def bottom(cls) -> "NSites":
return NoSites()

@classmethod
def top(cls) -> "NSites":
return AnySites()


@final
@dataclass
class NoSites(NSites, metaclass=SingletonMeta):

def is_subseteq(self, other: NSites) -> bool:
return True


@final
@dataclass
class AnySites(NSites, metaclass=SingletonMeta):

def is_subseteq(self, other: NSites) -> bool:
return isinstance(other, NSites)


@final
@dataclass
class HasNSites(NSites):
sites: int

def is_subseteq(self, other: NSites) -> bool:
if isinstance(other, HasNSites):
return self.sites == other.sites
return False
68 changes: 0 additions & 68 deletions src/bloqade/squin/analysis/shape/impls.py

This file was deleted.

49 changes: 0 additions & 49 deletions src/bloqade/squin/analysis/shape/lattice.py

This file was deleted.

14 changes: 7 additions & 7 deletions src/bloqade/squin/op/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.decl import info, statement

from .types import OpType
from .traits import Sized, HasSize, Unitary, MaybeUnitary
from .traits import Sites, Unitary, MaybeUnitary, HasNSitesTrait
from .complex import Complex
from ._dialect import dialect

Expand Down Expand Up @@ -77,21 +77,21 @@ class Rot(CompositeOp):

@statement(dialect=dialect)
class Identity(CompositeOp):
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSize()})
size: int = info.attribute()
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasNSitesTrait()})
sites: int = info.attribute()
result: ir.ResultValue = info.result(OpType)


@statement
class ConstantOp(PrimitiveOp):
traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sized(1)})
traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sites(1)})
result: ir.ResultValue = info.result(OpType)


@statement
class ConstantUnitary(ConstantOp):
traits = frozenset(
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sized(1)}
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sites(1)}
)


Expand All @@ -105,7 +105,7 @@ class PhaseOp(PrimitiveOp):
$$
"""

traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)})
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)})
theta: ir.SSAValue = info.argument(types.Float)
result: ir.ResultValue = info.result(OpType)

Expand All @@ -120,7 +120,7 @@ class ShiftOp(PrimitiveOp):
$$
"""

traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)})
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)})
theta: ir.SSAValue = info.argument(types.Float)
result: ir.ResultValue = info.result(OpType)

Expand Down
Loading