Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions squin_op_playground.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from kirin import ir, types
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you wanna delete the playground?

Copy link
Contributor Author

@johnzl-777 johnzl-777 Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! Thank you for catching that. I should probably do a local ignore for any _playground.py files moving forward

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually use a main.py as my playground it's in the .gitignore already IIRC

from kirin.passes import Fold
from kirin.dialects import py, func

from bloqade import squin
from bloqade.squin.analysis import shape


def as_int(value: int):
return py.constant.Constant(value=value)


squin_with_qasm_core = squin.groups.wired.add(py)

stmts: list[ir.Statement] = [
(h0 := squin.op.stmts.H()),
(h1 := squin.op.stmts.H()),
(hh := squin.op.stmts.Kron(lhs=h1.result, rhs=h0.result)),
(chh := squin.op.stmts.Control(hh.result, n_controls=1)),
(factor := as_int(1)),
(schh := squin.op.stmts.Scale(chh.result, factor=factor.result)),
(func.Return(schh.result)),
]

block = ir.Block(stmts)
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
func_wrapper = func.Function(
sym_name="main",
signature=func.Signature(inputs=(), output=squin.op.types.OpType),
body=ir.Region(blocks=block),
)

constructed_method = ir.Method(
mod=None,
py_func=None,
sym_name="main",
dialects=squin_with_qasm_core,
code=func_wrapper,
arg_names=[],
)

fold_pass = Fold(squin_with_qasm_core)
fold_pass(constructed_method)

""""
address_frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis(
constructed_method, no_raise=False
)


constructed_method.print(analysis=address_frame.entries)
"""

shape_frame, _ = shape.ShapeAnalysis(constructed_method.dialects).run_analysis(
constructed_method, no_raise=False
)


constructed_method.print(analysis=shape_frame.entries)
1 change: 1 addition & 0 deletions src/bloqade/analysis/address/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class AddressAnalysis(Forward[Address]):
def initialize(self):
super().initialize()
self.next_address: int = 0
print(self.registry.statements)
return self

@property
Expand Down
3 changes: 3 additions & 0 deletions src/bloqade/squin/analysis/shape/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Need this for impl registration to work properly!
from . import impls as impls
from .analysis import ShapeAnalysis as ShapeAnalysis
56 changes: 56 additions & 0 deletions src/bloqade/squin/analysis/shape/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# from typing import cast

from kirin import ir
from kirin.analysis import Forward
from kirin.analysis.forward import ForwardFrame

from bloqade.squin.op.types import OpType
from bloqade.squin.op.traits import Sized, HasSize

from .lattice import Shape, NoShape, OpShape


class ShapeAnalysis(Forward[Shape]):

keys = ["op.shape"]
lattice = Shape

def initialize(self):
super().initialize()
return self

# Take a page from const prop in Kirin,
# I can get the data I want from the SizedTrait
# and go from there

## This gets called before the registry look up
def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
method = self.lookup_registry(frame, stmt)
if method is not None:
return method(self, frame, stmt)
elif stmt.has_trait(HasSize):
has_size_inst = stmt.get_trait(HasSize)
size = has_size_inst.get_size(stmt)
return (OpShape(size=size),)
elif stmt.has_trait(Sized):
size = stmt.get_trait(Sized)
return (OpShape(size=size.data),)
else:
return (NoShape(),)

# For when no implementation is found for the statement
def eval_stmt_fallback(
self, frame: ForwardFrame[Shape], stmt: ir.Statement
) -> tuple[Shape, ...]: # some form of Shape will go back into the frame
return tuple(
(
self.lattice.top()
if result.type.is_subseteq(OpType)
else self.lattice.bottom()
)
for result in stmt.results
)

def run_method(self, method: ir.Method, args: tuple[Shape, ...]):
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
68 changes: 68 additions & 0 deletions src/bloqade/squin/analysis/shape/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import cast

from kirin import ir, interp

from bloqade.squin import op

from .lattice import (
NoShape,
OpShape,
)
from .analysis import ShapeAnalysis


@op.dialect.register(key="op.shape")
class SquinOp(interp.MethodTable):

@interp.impl(op.stmts.Kron)
def kron(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Kron):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)
if isinstance(lhs, OpShape) and isinstance(rhs, OpShape):
new_size = lhs.size + rhs.size
return (OpShape(size=new_size),)
else:
return (NoShape(),)

@interp.impl(op.stmts.Mult)
def mult(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Mult):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)

if isinstance(lhs, OpShape) and isinstance(rhs, OpShape):
lhs_size = lhs.size
rhs_size = rhs.size
# Sized trait implicitly enforces that
# all operators are square matrices,
# not sure if it's worth raising an exception here
# or just letting this propagate...
if lhs_size != rhs_size:
return (NoShape(),)
else:
return (OpShape(size=lhs_size + rhs_size),)
else:
return (NoShape(),)

@interp.impl(op.stmts.Control)
def control(
self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Control
):
op_shape = frame.get(stmt.op)

if isinstance(op_shape, OpShape):
op_size = op_shape.size
n_controls_attr = stmt.get_attr_or_prop("n_controls")
n_controls = cast(ir.PyAttr[int], n_controls_attr).data
return (OpShape(size=op_size + n_controls),)
else:
return (NoShape(),)

@interp.impl(op.stmts.Rot)
def rot(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Rot):
op_shape = frame.get(stmt.axis)
return (op_shape,)

@interp.impl(op.stmts.Scale)
def scale(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
op_shape = frame.get(stmt.op)
return (op_shape,)
49 changes: 49 additions & 0 deletions src/bloqade/squin/analysis/shape/lattice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import final
from dataclasses import dataclass

from kirin.lattice import (
SingletonMeta,
BoundedLattice,
SimpleJoinMixin,
SimpleMeetMixin,
)


@dataclass
class Shape(
SimpleJoinMixin["Shape"], SimpleMeetMixin["Shape"], BoundedLattice["Shape"]
):
@classmethod
def bottom(cls) -> "Shape":
return NoShape()

@classmethod
def top(cls) -> "Shape":
return AnyShape()


@final
@dataclass
class NoShape(Shape, metaclass=SingletonMeta):

def is_subseteq(self, other: Shape) -> bool:
return True


@final
@dataclass
class AnyShape(Shape, metaclass=SingletonMeta):

def is_subseteq(self, other: Shape) -> bool:
return isinstance(other, Shape)


@final
@dataclass
class OpShape(Shape):
size: int

def is_subseteq(self, other: Shape) -> bool:
if isinstance(other, OpShape):
return self.size == other.size
return False