From d9c7bbab25302e5d014058981af95b5672dc0631 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 8 Apr 2025 14:42:50 -0400 Subject: [PATCH 01/18] initial shape lattice --- src/bloqade/squin/analysis/shape/__init__.py | 0 src/bloqade/squin/analysis/shape/impls.py | 0 src/bloqade/squin/analysis/shape/lattice.py | 49 ++++++++++++++++++++ 3 files changed, 49 insertions(+) create mode 100644 src/bloqade/squin/analysis/shape/__init__.py create mode 100644 src/bloqade/squin/analysis/shape/impls.py create mode 100644 src/bloqade/squin/analysis/shape/lattice.py diff --git a/src/bloqade/squin/analysis/shape/__init__.py b/src/bloqade/squin/analysis/shape/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/squin/analysis/shape/impls.py b/src/bloqade/squin/analysis/shape/impls.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/squin/analysis/shape/lattice.py b/src/bloqade/squin/analysis/shape/lattice.py new file mode 100644 index 00000000..afd35251 --- /dev/null +++ b/src/bloqade/squin/analysis/shape/lattice.py @@ -0,0 +1,49 @@ +from typing import final +from dataclasses import dataclass + +from kirin.lattice import ( + SingletonMeta, + BoundedLattice, + SimpleJoinMixin, + SimpleMeetMixin, +) + + +@dataclass +class Shape( + SimpleJoinMixin["Shape"], SimpleMeetMixin["Shape"], BoundedLattice["Shape"] +): + @classmethod + def bottom(cls) -> "Shape": + return NoShape() + + @classmethod + def top(cls) -> "Shape": + return AnyShape() + + +@final +@dataclass +class NoShape(Shape, metaclass=SingletonMeta): + + def is_subseteq(self, other: Shape) -> bool: + return True + + +@final +@dataclass +class AnyShape(Shape, metaclass=SingletonMeta): + + def is_subseteq(self, other: Shape) -> bool: + return isinstance(other, Shape) + + +@final +@dataclass +class OpShape(Shape): + size: int + + def is_subseteq(self, other: Shape) -> bool: + if isinstance(other, OpShape): + return self.size == other.size + return False From 13135a35bf0ee1dad9d4145e1a8478274b7b19a7 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 8 Apr 2025 15:20:27 -0400 Subject: [PATCH 02/18] more analysis work --- src/bloqade/squin/analysis/shape/analysis.py | 33 ++++++++++++++++++++ src/bloqade/squin/analysis/shape/impls.py | 19 +++++++++++ 2 files changed, 52 insertions(+) create mode 100644 src/bloqade/squin/analysis/shape/analysis.py diff --git a/src/bloqade/squin/analysis/shape/analysis.py b/src/bloqade/squin/analysis/shape/analysis.py new file mode 100644 index 00000000..2a1eec6e --- /dev/null +++ b/src/bloqade/squin/analysis/shape/analysis.py @@ -0,0 +1,33 @@ +from kirin import ir, interp +from kirin.analysis import Forward +from kirin.analysis.forward import ForwardFrame + +from bloqade.squin.op.types import OpType + +from .lattice import Shape + + +class ShapeAnalysis(Forward[Shape]): + + keys = ["op.shape"] + lattice = Shape + + def initialize(self): + super().initialize + return self + + def eval_stmt_fallback( + self, frame: ForwardFrame[Shape], stmt: ir.Statement + ) -> tuple[Shape, ...] | interp.SpecialValue[Shape]: + return tuple( + ( + self.lattice.top() + if result.type.is_subseteq(OpType) + else self.lattice.bottom() + ) + for result in stmt.results + ) + + def run_method(self, method: ir.Method, args: tuple[Shape, ...]): + # NOTE: we do not support dynamic calls here, thus no need to propagate method object + return self.run_callable(method.code, (self.lattice.bottom(),) + args) diff --git a/src/bloqade/squin/analysis/shape/impls.py b/src/bloqade/squin/analysis/shape/impls.py index e69de29b..56bc441f 100644 --- a/src/bloqade/squin/analysis/shape/impls.py +++ b/src/bloqade/squin/analysis/shape/impls.py @@ -0,0 +1,19 @@ +from kirin import interp + +from bloqade import squin + +""" from .lattice import ( + Shape, + NoShape, + OpShape, +) + +from .analysis import ShapeAnalysis """ + + +@squin.op.dialect.register(key="op.shape") +class SquinOp(interp.MethodTable): + + # Should be using the Sized trait + # that the statements have + pass From 9aa29645fd835fc406dac48f5779f7fa9c35e0c6 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 8 Apr 2025 16:19:40 -0400 Subject: [PATCH 03/18] getting there --- squin_op_playground.py | 53 ++++++++++++++++++++ src/bloqade/squin/analysis/shape/analysis.py | 21 +++++++- src/bloqade/squin/analysis/shape/impls.py | 7 ++- 3 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 squin_op_playground.py diff --git a/squin_op_playground.py b/squin_op_playground.py new file mode 100644 index 00000000..f4fc3d86 --- /dev/null +++ b/squin_op_playground.py @@ -0,0 +1,53 @@ +from kirin import ir, types +from kirin.passes import Fold +from kirin.dialects import py, func, ilist + +from bloqade import qasm2, squin +from bloqade.analysis import address + + +def as_int(value: int): + return py.constant.Constant(value=value) + + +squin_with_qasm_core = squin.groups.wired.add(qasm2.core).add(ilist) + +stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # Unwrap to get wires + (w1 := squin.wire.Unwrap(qubit=q1.result)), + # Put them in an ilist and return to prevent elimination + # Put the wire into one operator + (op := squin.op.stmts.H()), + (v1 := squin.wire.Apply(op.result, w1.result)), + (func.Return(v1.results[0])), +] + +block = ir.Block(stmts) +block.args.append_from(types.MethodType[[], types.NoneType], "main_self") +func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=squin.wire.WireType), + body=ir.Region(blocks=block), +) + +constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=squin_with_qasm_core, + code=func_wrapper, + arg_names=[], +) + +fold_pass = Fold(squin_with_qasm_core) +fold_pass(constructed_method) + +frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False +) diff --git a/src/bloqade/squin/analysis/shape/analysis.py b/src/bloqade/squin/analysis/shape/analysis.py index 2a1eec6e..e2fbc91b 100644 --- a/src/bloqade/squin/analysis/shape/analysis.py +++ b/src/bloqade/squin/analysis/shape/analysis.py @@ -1,10 +1,13 @@ +# from typing import cast + from kirin import ir, interp from kirin.analysis import Forward from kirin.analysis.forward import ForwardFrame from bloqade.squin.op.types import OpType +from bloqade.squin.op.traits import Sized, HasSize -from .lattice import Shape +from .lattice import Shape, OpShape, AnyShape class ShapeAnalysis(Forward[Shape]): @@ -16,6 +19,22 @@ def initialize(self): super().initialize return self + # Take a page from const prop in Kirin, + # I can get the data I want from the SizedTrait + # and go from there + def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): + if stmt.has_trait(Sized): + size = stmt.get_trait(Sized) + return (OpShape(size=size),) + # Handle op.Identity + elif stmt.has_trait(HasSize): + # Caution! This can return None + has_size_inst = stmt.get_trait(HasSize) + size = has_size_inst.get_size(stmt) + return (OpShape(size=size),) + else: + return (AnyShape(),) + def eval_stmt_fallback( self, frame: ForwardFrame[Shape], stmt: ir.Statement ) -> tuple[Shape, ...] | interp.SpecialValue[Shape]: diff --git a/src/bloqade/squin/analysis/shape/impls.py b/src/bloqade/squin/analysis/shape/impls.py index 56bc441f..4c33ff78 100644 --- a/src/bloqade/squin/analysis/shape/impls.py +++ b/src/bloqade/squin/analysis/shape/impls.py @@ -13,7 +13,12 @@ @squin.op.dialect.register(key="op.shape") class SquinOp(interp.MethodTable): + pass # Should be using the Sized trait # that the statements have - pass + + # Need to keep in mind that Identity + # has a HasSize() trait with "size:int" + # as the corresponding attribute to query + # @interp.impl(squin.op.stmts.ConstantUnitary) From 9ae811cf6945530b774b500c95f86aa4f3df9e0d Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 8 Apr 2025 16:54:33 -0400 Subject: [PATCH 04/18] proper handling of Sized(1) trait operators --- squin_op_playground.py | 27 +++++++++++++++----- src/bloqade/squin/analysis/shape/__init__.py | 1 + src/bloqade/squin/analysis/shape/analysis.py | 8 +++--- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/squin_op_playground.py b/squin_op_playground.py index f4fc3d86..b9c0c7a5 100644 --- a/squin_op_playground.py +++ b/squin_op_playground.py @@ -4,6 +4,7 @@ from bloqade import qasm2, squin from bloqade.analysis import address +from bloqade.squin.analysis import shape def as_int(value: int): @@ -18,21 +19,24 @@ def as_int(value: int): (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), # Get qubits out (idx0 := as_int(0)), - (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), # Unwrap to get wires - (w1 := squin.wire.Unwrap(qubit=q1.result)), - # Put them in an ilist and return to prevent elimination - # Put the wire into one operator + (w1 := squin.wire.Unwrap(qubit=q0.result)), + # Pass wire into operator (op := squin.op.stmts.H()), (v1 := squin.wire.Apply(op.result, w1.result)), - (func.Return(v1.results[0])), + # Test Identity + (id := squin.op.stmts.Identity(size=1)), + (v2 := squin.wire.Apply(id.result, v1.results[0])), + # Keep Passing Operators + (func.Return(v2.results[0])), ] block = ir.Block(stmts) block.args.append_from(types.MethodType[[], types.NoneType], "main_self") func_wrapper = func.Function( sym_name="main", - signature=func.Signature(inputs=(), output=squin.wire.WireType), + signature=func.Signature(inputs=(), output=ilist.IListType), body=ir.Region(blocks=block), ) @@ -51,3 +55,14 @@ def as_int(value: int): frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( constructed_method, no_raise=False ) + +frame, _ = shape.ShapeAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False +) + +""" +frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False +""" + +constructed_method.print(analysis=frame.entries) diff --git a/src/bloqade/squin/analysis/shape/__init__.py b/src/bloqade/squin/analysis/shape/__init__.py index e69de29b..ec943466 100644 --- a/src/bloqade/squin/analysis/shape/__init__.py +++ b/src/bloqade/squin/analysis/shape/__init__.py @@ -0,0 +1 @@ +from .analysis import ShapeAnalysis as ShapeAnalysis diff --git a/src/bloqade/squin/analysis/shape/analysis.py b/src/bloqade/squin/analysis/shape/analysis.py index e2fbc91b..6ca375da 100644 --- a/src/bloqade/squin/analysis/shape/analysis.py +++ b/src/bloqade/squin/analysis/shape/analysis.py @@ -7,7 +7,7 @@ from bloqade.squin.op.types import OpType from bloqade.squin.op.traits import Sized, HasSize -from .lattice import Shape, OpShape, AnyShape +from .lattice import Shape, NoShape, OpShape class ShapeAnalysis(Forward[Shape]): @@ -16,7 +16,7 @@ class ShapeAnalysis(Forward[Shape]): lattice = Shape def initialize(self): - super().initialize + super().initialize() return self # Take a page from const prop in Kirin, @@ -25,7 +25,7 @@ def initialize(self): def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): if stmt.has_trait(Sized): size = stmt.get_trait(Sized) - return (OpShape(size=size),) + return (OpShape(size=size.data),) # Handle op.Identity elif stmt.has_trait(HasSize): # Caution! This can return None @@ -33,7 +33,7 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): size = has_size_inst.get_size(stmt) return (OpShape(size=size),) else: - return (AnyShape(),) + return (NoShape(),) def eval_stmt_fallback( self, frame: ForwardFrame[Shape], stmt: ir.Statement From 0c20b959ab46b7331bd6128f656af4a5dd309d5c Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 8 Apr 2025 22:05:59 -0400 Subject: [PATCH 05/18] Completed implementation but not quite working --- squin_op_playground.py | 28 +++----- src/bloqade/squin/analysis/shape/analysis.py | 10 +-- src/bloqade/squin/analysis/shape/impls.py | 73 ++++++++++++++++---- 3 files changed, 73 insertions(+), 38 deletions(-) diff --git a/squin_op_playground.py b/squin_op_playground.py index b9c0c7a5..840eb888 100644 --- a/squin_op_playground.py +++ b/squin_op_playground.py @@ -1,8 +1,8 @@ from kirin import ir, types from kirin.passes import Fold -from kirin.dialects import py, func, ilist +from kirin.dialects import py, func -from bloqade import qasm2, squin +from bloqade import squin from bloqade.analysis import address from bloqade.squin.analysis import shape @@ -11,32 +11,20 @@ def as_int(value: int): return py.constant.Constant(value=value) -squin_with_qasm_core = squin.groups.wired.add(qasm2.core).add(ilist) +squin_with_qasm_core = squin.groups.wired stmts: list[ir.Statement] = [ - # Create qubit register - (n_qubits := as_int(1)), - (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), - # Get qubits out - (idx0 := as_int(0)), - (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), - # Unwrap to get wires - (w1 := squin.wire.Unwrap(qubit=q0.result)), - # Pass wire into operator - (op := squin.op.stmts.H()), - (v1 := squin.wire.Apply(op.result, w1.result)), - # Test Identity - (id := squin.op.stmts.Identity(size=1)), - (v2 := squin.wire.Apply(id.result, v1.results[0])), - # Keep Passing Operators - (func.Return(v2.results[0])), + (h0 := squin.op.stmts.H()), + (h1 := squin.op.stmts.H()), + (hh := squin.op.stmts.Kron(lhs=h1.result, rhs=h0.result)), + (func.Return(hh.result)), ] block = ir.Block(stmts) block.args.append_from(types.MethodType[[], types.NoneType], "main_self") func_wrapper = func.Function( sym_name="main", - signature=func.Signature(inputs=(), output=ilist.IListType), + signature=func.Signature(inputs=(), output=squin.op.types.OpType), body=ir.Region(blocks=block), ) diff --git a/src/bloqade/squin/analysis/shape/analysis.py b/src/bloqade/squin/analysis/shape/analysis.py index 6ca375da..2591c880 100644 --- a/src/bloqade/squin/analysis/shape/analysis.py +++ b/src/bloqade/squin/analysis/shape/analysis.py @@ -23,15 +23,17 @@ def initialize(self): # I can get the data I want from the SizedTrait # and go from there def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): - if stmt.has_trait(Sized): - size = stmt.get_trait(Sized) - return (OpShape(size=size.data),) - # Handle op.Identity + method = self.lookup_registry(frame, stmt) + if method is not None: + return method(self, frame, stmt) elif stmt.has_trait(HasSize): # Caution! This can return None has_size_inst = stmt.get_trait(HasSize) size = has_size_inst.get_size(stmt) return (OpShape(size=size),) + elif stmt.has_trait(Sized): + size = stmt.get_trait(Sized) + return (OpShape(size=size.data),) else: return (NoShape(),) diff --git a/src/bloqade/squin/analysis/shape/impls.py b/src/bloqade/squin/analysis/shape/impls.py index 4c33ff78..9bbd9cef 100644 --- a/src/bloqade/squin/analysis/shape/impls.py +++ b/src/bloqade/squin/analysis/shape/impls.py @@ -1,24 +1,69 @@ -from kirin import interp +from typing import cast -from bloqade import squin +from kirin import ir, interp -""" from .lattice import ( - Shape, +from bloqade.squin import op + +from .lattice import ( NoShape, OpShape, ) - -from .analysis import ShapeAnalysis """ +from .analysis import ShapeAnalysis -@squin.op.dialect.register(key="op.shape") +@op.dialect.register(key="op.shape") class SquinOp(interp.MethodTable): - pass - # Should be using the Sized trait - # that the statements have + @interp.impl(op.stmts.Kron) + def kron(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Kron): + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + if isinstance(lhs, OpShape) and isinstance(rhs, OpShape): + new_size = lhs.size + rhs.size + return (OpShape(size=new_size),) + else: + return (NoShape(),) + + @interp.impl(op.stmts.Mult) + def mult(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Mult): + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + + if isinstance(lhs, OpShape) and isinstance(rhs, OpShape): + lhs_size = lhs.size + rhs_size = rhs.size + # Sized trait implicitly enforces that + # all operators are square matrices, + # not sure if it's worth raising an exception here + # or just letting this propagate... + if lhs_size != rhs_size: + return (NoShape(),) + else: + return (OpShape(size=lhs_size + rhs_size),) + else: + return (NoShape(),) + + @interp.impl(op.stmts.Control) + def control( + self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Control + ): + op_shape = frame.get(stmt.op) + + if isinstance(op_shape, OpShape): + op_size = op_shape.size + n_controls_attr = stmt.get_attr_or_prop("n_controls") + # raise exception if attribute is NOne + n_controls = cast(ir.PyAttr[int], n_controls_attr).data + return (OpShape(size=op_size + n_controls),) + else: + return (NoShape(),) + + @interp.impl(op.stmts.Rot) + def rot(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Rot): + op_shape = frame.get(stmt.axis) + return op_shape - # Need to keep in mind that Identity - # has a HasSize() trait with "size:int" - # as the corresponding attribute to query - # @interp.impl(squin.op.stmts.ConstantUnitary) + @interp.impl(op.stmts.Scale) + def scale(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Scale): + op_shape = frame.get(stmt.op) + return op_shape From 5e1c2211fc40e67ae101011f4ebd65bb80c145be Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 9 Apr 2025 09:45:52 -0400 Subject: [PATCH 06/18] debugging empty registry problem --- squin_op_playground.py | 8 +++++--- src/bloqade/analysis/address/analysis.py | 1 + src/bloqade/squin/analysis/shape/analysis.py | 10 ++++++++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/squin_op_playground.py b/squin_op_playground.py index 840eb888..d82a2b54 100644 --- a/squin_op_playground.py +++ b/squin_op_playground.py @@ -40,11 +40,13 @@ def as_int(value: int): fold_pass = Fold(squin_with_qasm_core) fold_pass(constructed_method) -frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( +address_frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( constructed_method, no_raise=False ) -frame, _ = shape.ShapeAnalysis(constructed_method.dialects).run_analysis( +constructed_method.print(analysis=address_frame.entries) + +shape_frame, _ = shape.ShapeAnalysis(constructed_method.dialects).run_analysis( constructed_method, no_raise=False ) @@ -53,4 +55,4 @@ def as_int(value: int): constructed_method, no_raise=False """ -constructed_method.print(analysis=frame.entries) +constructed_method.print(analysis=shape_frame.entries) diff --git a/src/bloqade/analysis/address/analysis.py b/src/bloqade/analysis/address/analysis.py index d1438b00..22d97eeb 100644 --- a/src/bloqade/analysis/address/analysis.py +++ b/src/bloqade/analysis/address/analysis.py @@ -22,6 +22,7 @@ class AddressAnalysis(Forward[Address]): def initialize(self): super().initialize() self.next_address: int = 0 + print(self.registry.statements) return self @property diff --git a/src/bloqade/squin/analysis/shape/analysis.py b/src/bloqade/squin/analysis/shape/analysis.py index 2591c880..8c727a10 100644 --- a/src/bloqade/squin/analysis/shape/analysis.py +++ b/src/bloqade/squin/analysis/shape/analysis.py @@ -1,6 +1,6 @@ # from typing import cast -from kirin import ir, interp +from kirin import ir from kirin.analysis import Forward from kirin.analysis.forward import ForwardFrame @@ -17,12 +17,17 @@ class ShapeAnalysis(Forward[Shape]): def initialize(self): super().initialize() + print(self.registry.statements) return self # Take a page from const prop in Kirin, # I can get the data I want from the SizedTrait # and go from there + + ## This gets called before the registry look up def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): + # something fishy, registry pops up empty? + # This doesn't happen with the method = self.lookup_registry(frame, stmt) if method is not None: return method(self, frame, stmt) @@ -37,9 +42,10 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): else: return (NoShape(),) + # For when no implementation is found for the statement def eval_stmt_fallback( self, frame: ForwardFrame[Shape], stmt: ir.Statement - ) -> tuple[Shape, ...] | interp.SpecialValue[Shape]: + ) -> tuple[Shape, ...]: # some form of Shape will go back into the frame return tuple( ( self.lattice.top() From d765f093478558e72d65213522d4328f7bfb9762 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 9 Apr 2025 13:44:18 -0400 Subject: [PATCH 07/18] fixed missing registry + improper impl returns --- squin_op_playground.py | 16 +++++++++------- src/bloqade/squin/analysis/shape/__init__.py | 2 ++ src/bloqade/squin/analysis/shape/analysis.py | 2 -- src/bloqade/squin/analysis/shape/impls.py | 4 ++-- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/squin_op_playground.py b/squin_op_playground.py index d82a2b54..05cd1d4e 100644 --- a/squin_op_playground.py +++ b/squin_op_playground.py @@ -3,7 +3,6 @@ from kirin.dialects import py, func from bloqade import squin -from bloqade.analysis import address from bloqade.squin.analysis import shape @@ -11,13 +10,17 @@ def as_int(value: int): return py.constant.Constant(value=value) -squin_with_qasm_core = squin.groups.wired +squin_with_qasm_core = squin.groups.wired.add(py) stmts: list[ir.Statement] = [ (h0 := squin.op.stmts.H()), (h1 := squin.op.stmts.H()), (hh := squin.op.stmts.Kron(lhs=h1.result, rhs=h0.result)), - (func.Return(hh.result)), + (chh := squin.op.stmts.Control(hh.result, n_controls=1)), + (factor := as_int(1)), + # schh for some reason causes it to blow up + (schh := squin.op.stmts.Scale(chh.result, factor=factor.result)), + (func.Return(schh.result)), ] block = ir.Block(stmts) @@ -40,19 +43,18 @@ def as_int(value: int): fold_pass = Fold(squin_with_qasm_core) fold_pass(constructed_method) +"""" address_frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( constructed_method, no_raise=False ) + constructed_method.print(analysis=address_frame.entries) +""" shape_frame, _ = shape.ShapeAnalysis(constructed_method.dialects).run_analysis( constructed_method, no_raise=False ) -""" -frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( - constructed_method, no_raise=False -""" constructed_method.print(analysis=shape_frame.entries) diff --git a/src/bloqade/squin/analysis/shape/__init__.py b/src/bloqade/squin/analysis/shape/__init__.py index ec943466..08c8fa45 100644 --- a/src/bloqade/squin/analysis/shape/__init__.py +++ b/src/bloqade/squin/analysis/shape/__init__.py @@ -1 +1,3 @@ +# Need this for impl registration to work properly! +from . import impls as impls from .analysis import ShapeAnalysis as ShapeAnalysis diff --git a/src/bloqade/squin/analysis/shape/analysis.py b/src/bloqade/squin/analysis/shape/analysis.py index 8c727a10..5db3d355 100644 --- a/src/bloqade/squin/analysis/shape/analysis.py +++ b/src/bloqade/squin/analysis/shape/analysis.py @@ -26,8 +26,6 @@ def initialize(self): ## This gets called before the registry look up def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): - # something fishy, registry pops up empty? - # This doesn't happen with the method = self.lookup_registry(frame, stmt) if method is not None: return method(self, frame, stmt) diff --git a/src/bloqade/squin/analysis/shape/impls.py b/src/bloqade/squin/analysis/shape/impls.py index 9bbd9cef..1250de30 100644 --- a/src/bloqade/squin/analysis/shape/impls.py +++ b/src/bloqade/squin/analysis/shape/impls.py @@ -61,9 +61,9 @@ def control( @interp.impl(op.stmts.Rot) def rot(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Rot): op_shape = frame.get(stmt.axis) - return op_shape + return (op_shape,) @interp.impl(op.stmts.Scale) def scale(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Scale): op_shape = frame.get(stmt.op) - return op_shape + return (op_shape,) From a12dae121ff19dfee2375ae649f752e555b66cd0 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 9 Apr 2025 13:46:09 -0400 Subject: [PATCH 08/18] remove incorrect comment --- squin_op_playground.py | 1 - 1 file changed, 1 deletion(-) diff --git a/squin_op_playground.py b/squin_op_playground.py index 05cd1d4e..bdcb5128 100644 --- a/squin_op_playground.py +++ b/squin_op_playground.py @@ -18,7 +18,6 @@ def as_int(value: int): (hh := squin.op.stmts.Kron(lhs=h1.result, rhs=h0.result)), (chh := squin.op.stmts.Control(hh.result, n_controls=1)), (factor := as_int(1)), - # schh for some reason causes it to blow up (schh := squin.op.stmts.Scale(chh.result, factor=factor.result)), (func.Return(schh.result)), ] From a1b0da026351d72b43dd7f67b49c2614da5622fa Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 9 Apr 2025 14:15:53 -0400 Subject: [PATCH 09/18] remove incorrect comment --- src/bloqade/squin/analysis/shape/impls.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bloqade/squin/analysis/shape/impls.py b/src/bloqade/squin/analysis/shape/impls.py index 1250de30..878d3a7d 100644 --- a/src/bloqade/squin/analysis/shape/impls.py +++ b/src/bloqade/squin/analysis/shape/impls.py @@ -52,7 +52,6 @@ def control( if isinstance(op_shape, OpShape): op_size = op_shape.size n_controls_attr = stmt.get_attr_or_prop("n_controls") - # raise exception if attribute is NOne n_controls = cast(ir.PyAttr[int], n_controls_attr).data return (OpShape(size=op_size + n_controls),) else: From 32cc7340ee266dbe0945d270b3ad6ea9909f8907 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 9 Apr 2025 14:19:32 -0400 Subject: [PATCH 10/18] remove lingering print --- src/bloqade/squin/analysis/shape/analysis.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bloqade/squin/analysis/shape/analysis.py b/src/bloqade/squin/analysis/shape/analysis.py index 5db3d355..0e2fdd39 100644 --- a/src/bloqade/squin/analysis/shape/analysis.py +++ b/src/bloqade/squin/analysis/shape/analysis.py @@ -17,7 +17,6 @@ class ShapeAnalysis(Forward[Shape]): def initialize(self): super().initialize() - print(self.registry.statements) return self # Take a page from const prop in Kirin, From 7af6fb0766899f45ea4a1ab52bb3e67a9ea94ada Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 9 Apr 2025 14:20:03 -0400 Subject: [PATCH 11/18] remove incorrect comment --- src/bloqade/squin/analysis/shape/analysis.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bloqade/squin/analysis/shape/analysis.py b/src/bloqade/squin/analysis/shape/analysis.py index 0e2fdd39..d157ccfc 100644 --- a/src/bloqade/squin/analysis/shape/analysis.py +++ b/src/bloqade/squin/analysis/shape/analysis.py @@ -29,7 +29,6 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): if method is not None: return method(self, frame, stmt) elif stmt.has_trait(HasSize): - # Caution! This can return None has_size_inst = stmt.get_trait(HasSize) size = has_size_inst.get_size(stmt) return (OpShape(size=size),) From accb46041f35e465c54112db48ff0486ee83763a Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 9 Apr 2025 22:27:59 -0400 Subject: [PATCH 12/18] rename to use sites instead of shape --- squin_op_playground.py | 4 +- src/bloqade/analysis/address/analysis.py | 1 - .../analysis/{shape => nsites}/__init__.py | 2 +- .../analysis/{shape => nsites}/analysis.py | 36 +++++----- src/bloqade/squin/analysis/nsites/impls.py | 69 +++++++++++++++++++ src/bloqade/squin/analysis/nsites/lattice.py | 49 +++++++++++++ src/bloqade/squin/analysis/shape/impls.py | 68 ------------------ src/bloqade/squin/analysis/shape/lattice.py | 49 ------------- src/bloqade/squin/op/stmts.py | 14 ++-- src/bloqade/squin/op/traits.py | 16 ++--- 10 files changed, 152 insertions(+), 156 deletions(-) rename src/bloqade/squin/analysis/{shape => nsites}/__init__.py (59%) rename src/bloqade/squin/analysis/{shape => nsites}/analysis.py (56%) create mode 100644 src/bloqade/squin/analysis/nsites/impls.py create mode 100644 src/bloqade/squin/analysis/nsites/lattice.py delete mode 100644 src/bloqade/squin/analysis/shape/impls.py delete mode 100644 src/bloqade/squin/analysis/shape/lattice.py diff --git a/squin_op_playground.py b/squin_op_playground.py index bdcb5128..1235585a 100644 --- a/squin_op_playground.py +++ b/squin_op_playground.py @@ -3,7 +3,7 @@ from kirin.dialects import py, func from bloqade import squin -from bloqade.squin.analysis import shape +from bloqade.squin.analysis import nsites def as_int(value: int): @@ -51,7 +51,7 @@ def as_int(value: int): constructed_method.print(analysis=address_frame.entries) """ -shape_frame, _ = shape.ShapeAnalysis(constructed_method.dialects).run_analysis( +shape_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( constructed_method, no_raise=False ) diff --git a/src/bloqade/analysis/address/analysis.py b/src/bloqade/analysis/address/analysis.py index 22d97eeb..d1438b00 100644 --- a/src/bloqade/analysis/address/analysis.py +++ b/src/bloqade/analysis/address/analysis.py @@ -22,7 +22,6 @@ class AddressAnalysis(Forward[Address]): def initialize(self): super().initialize() self.next_address: int = 0 - print(self.registry.statements) return self @property diff --git a/src/bloqade/squin/analysis/shape/__init__.py b/src/bloqade/squin/analysis/nsites/__init__.py similarity index 59% rename from src/bloqade/squin/analysis/shape/__init__.py rename to src/bloqade/squin/analysis/nsites/__init__.py index 08c8fa45..c6b10354 100644 --- a/src/bloqade/squin/analysis/shape/__init__.py +++ b/src/bloqade/squin/analysis/nsites/__init__.py @@ -1,3 +1,3 @@ # Need this for impl registration to work properly! from . import impls as impls -from .analysis import ShapeAnalysis as ShapeAnalysis +from .analysis import NSitesAnalysis as NSitesAnalysis diff --git a/src/bloqade/squin/analysis/shape/analysis.py b/src/bloqade/squin/analysis/nsites/analysis.py similarity index 56% rename from src/bloqade/squin/analysis/shape/analysis.py rename to src/bloqade/squin/analysis/nsites/analysis.py index d157ccfc..57f89566 100644 --- a/src/bloqade/squin/analysis/shape/analysis.py +++ b/src/bloqade/squin/analysis/nsites/analysis.py @@ -5,19 +5,15 @@ from kirin.analysis.forward import ForwardFrame from bloqade.squin.op.types import OpType -from bloqade.squin.op.traits import Sized, HasSize +from bloqade.squin.op.traits import Sites, HasNSitesTrait -from .lattice import Shape, NoShape, OpShape +from .lattice import NSites, NoSites, HasNSites -class ShapeAnalysis(Forward[Shape]): +class NSitesAnalysis(Forward[NSites]): - keys = ["op.shape"] - lattice = Shape - - def initialize(self): - super().initialize() - return self + keys = ["op.nsites"] + lattice = NSites # Take a page from const prop in Kirin, # I can get the data I want from the SizedTrait @@ -28,20 +24,20 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): method = self.lookup_registry(frame, stmt) if method is not None: return method(self, frame, stmt) - elif stmt.has_trait(HasSize): - has_size_inst = stmt.get_trait(HasSize) - size = has_size_inst.get_size(stmt) - return (OpShape(size=size),) - elif stmt.has_trait(Sized): - size = stmt.get_trait(Sized) - return (OpShape(size=size.data),) + elif stmt.has_trait(HasNSitesTrait): + has_n_sites_trait = stmt.get_trait(HasNSitesTrait) + sites = has_n_sites_trait.get_sites(stmt) + return (HasNSites(sites=sites),) + elif stmt.has_trait(Sites): + sites_trait = stmt.get_trait(Sites) + return (HasNSites(sites=sites_trait.data),) else: - return (NoShape(),) + return (NoSites(),) # For when no implementation is found for the statement def eval_stmt_fallback( - self, frame: ForwardFrame[Shape], stmt: ir.Statement - ) -> tuple[Shape, ...]: # some form of Shape will go back into the frame + self, frame: ForwardFrame[NSites], stmt: ir.Statement + ) -> tuple[NSites, ...]: # some form of Shape will go back into the frame return tuple( ( self.lattice.top() @@ -51,6 +47,6 @@ def eval_stmt_fallback( for result in stmt.results ) - def run_method(self, method: ir.Method, args: tuple[Shape, ...]): + def run_method(self, method: ir.Method, args: tuple[NSites, ...]): # NOTE: we do not support dynamic calls here, thus no need to propagate method object return self.run_callable(method.code, (self.lattice.bottom(),) + args) diff --git a/src/bloqade/squin/analysis/nsites/impls.py b/src/bloqade/squin/analysis/nsites/impls.py new file mode 100644 index 00000000..5f9a9fe8 --- /dev/null +++ b/src/bloqade/squin/analysis/nsites/impls.py @@ -0,0 +1,69 @@ +from typing import cast + +from kirin import ir, interp + +from bloqade.squin import op + +from .lattice import ( + NoSites, + HasNSites, +) +from .analysis import NSitesAnalysis + + +@op.dialect.register(key="op.nsites") +class SquinOp(interp.MethodTable): + + @interp.impl(op.stmts.Kron) + def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron): + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites): + new_n_sites = lhs.sites + rhs.sites + return (HasNSites(sites=new_n_sites),) + else: + return (NoSites(),) + + @interp.impl(op.stmts.Mult) + def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult): + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + + if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites): + lhs_sites = lhs.sites + rhs_sites = rhs.sites + # I originally considered throwing an exception here + # but Xiu-zhe (Roger) Luo has pointed out it would be + # a much better UX to add a type element that + # could explicitly indicate the error. The downside + # is you'll have some added complexity in the type lattice. + if lhs_sites != rhs_sites: + return (NoSites(),) + else: + return (HasNSites(sites=lhs_sites + rhs_sites),) + else: + return (NoSites(),) + + @interp.impl(op.stmts.Control) + def control( + self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Control + ): + op_sites = frame.get(stmt.op) + + if isinstance(op_sites, HasNSites): + n_sites = op_sites.sites + n_controls_attr = stmt.get_attr_or_prop("n_controls") + n_controls = cast(ir.PyAttr[int], n_controls_attr).data + return (HasNSites(sites=n_sites + n_controls),) + else: + return (NoSites(),) + + @interp.impl(op.stmts.Rot) + def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot): + op_sites = frame.get(stmt.axis) + return (op_sites,) + + @interp.impl(op.stmts.Scale) + def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale): + op_sites = frame.get(stmt.op) + return (op_sites,) diff --git a/src/bloqade/squin/analysis/nsites/lattice.py b/src/bloqade/squin/analysis/nsites/lattice.py new file mode 100644 index 00000000..44e10c25 --- /dev/null +++ b/src/bloqade/squin/analysis/nsites/lattice.py @@ -0,0 +1,49 @@ +from typing import final +from dataclasses import dataclass + +from kirin.lattice import ( + SingletonMeta, + BoundedLattice, + SimpleJoinMixin, + SimpleMeetMixin, +) + + +@dataclass +class NSites( + SimpleJoinMixin["NSites"], SimpleMeetMixin["NSites"], BoundedLattice["NSites"] +): + @classmethod + def bottom(cls) -> "NSites": + return NoSites() + + @classmethod + def top(cls) -> "NSites": + return AnySites() + + +@final +@dataclass +class NoSites(NSites, metaclass=SingletonMeta): + + def is_subseteq(self, other: NSites) -> bool: + return True + + +@final +@dataclass +class AnySites(NSites, metaclass=SingletonMeta): + + def is_subseteq(self, other: NSites) -> bool: + return isinstance(other, NSites) + + +@final +@dataclass +class HasNSites(NSites): + sites: int + + def is_subseteq(self, other: NSites) -> bool: + if isinstance(other, HasNSites): + return self.sites == other.sites + return False diff --git a/src/bloqade/squin/analysis/shape/impls.py b/src/bloqade/squin/analysis/shape/impls.py deleted file mode 100644 index 878d3a7d..00000000 --- a/src/bloqade/squin/analysis/shape/impls.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import cast - -from kirin import ir, interp - -from bloqade.squin import op - -from .lattice import ( - NoShape, - OpShape, -) -from .analysis import ShapeAnalysis - - -@op.dialect.register(key="op.shape") -class SquinOp(interp.MethodTable): - - @interp.impl(op.stmts.Kron) - def kron(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Kron): - lhs = frame.get(stmt.lhs) - rhs = frame.get(stmt.rhs) - if isinstance(lhs, OpShape) and isinstance(rhs, OpShape): - new_size = lhs.size + rhs.size - return (OpShape(size=new_size),) - else: - return (NoShape(),) - - @interp.impl(op.stmts.Mult) - def mult(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Mult): - lhs = frame.get(stmt.lhs) - rhs = frame.get(stmt.rhs) - - if isinstance(lhs, OpShape) and isinstance(rhs, OpShape): - lhs_size = lhs.size - rhs_size = rhs.size - # Sized trait implicitly enforces that - # all operators are square matrices, - # not sure if it's worth raising an exception here - # or just letting this propagate... - if lhs_size != rhs_size: - return (NoShape(),) - else: - return (OpShape(size=lhs_size + rhs_size),) - else: - return (NoShape(),) - - @interp.impl(op.stmts.Control) - def control( - self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Control - ): - op_shape = frame.get(stmt.op) - - if isinstance(op_shape, OpShape): - op_size = op_shape.size - n_controls_attr = stmt.get_attr_or_prop("n_controls") - n_controls = cast(ir.PyAttr[int], n_controls_attr).data - return (OpShape(size=op_size + n_controls),) - else: - return (NoShape(),) - - @interp.impl(op.stmts.Rot) - def rot(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Rot): - op_shape = frame.get(stmt.axis) - return (op_shape,) - - @interp.impl(op.stmts.Scale) - def scale(self, interp: ShapeAnalysis, frame: interp.Frame, stmt: op.stmts.Scale): - op_shape = frame.get(stmt.op) - return (op_shape,) diff --git a/src/bloqade/squin/analysis/shape/lattice.py b/src/bloqade/squin/analysis/shape/lattice.py deleted file mode 100644 index afd35251..00000000 --- a/src/bloqade/squin/analysis/shape/lattice.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import final -from dataclasses import dataclass - -from kirin.lattice import ( - SingletonMeta, - BoundedLattice, - SimpleJoinMixin, - SimpleMeetMixin, -) - - -@dataclass -class Shape( - SimpleJoinMixin["Shape"], SimpleMeetMixin["Shape"], BoundedLattice["Shape"] -): - @classmethod - def bottom(cls) -> "Shape": - return NoShape() - - @classmethod - def top(cls) -> "Shape": - return AnyShape() - - -@final -@dataclass -class NoShape(Shape, metaclass=SingletonMeta): - - def is_subseteq(self, other: Shape) -> bool: - return True - - -@final -@dataclass -class AnyShape(Shape, metaclass=SingletonMeta): - - def is_subseteq(self, other: Shape) -> bool: - return isinstance(other, Shape) - - -@final -@dataclass -class OpShape(Shape): - size: int - - def is_subseteq(self, other: Shape) -> bool: - if isinstance(other, OpShape): - return self.size == other.size - return False diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index 09fb6052..3615e30c 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -2,7 +2,7 @@ from kirin.decl import info, statement from .types import OpType -from .traits import Sized, HasSize, Unitary, MaybeUnitary +from .traits import Sites, Unitary, MaybeUnitary, HasNSitesTrait from .complex import Complex from ._dialect import dialect @@ -77,21 +77,21 @@ class Rot(CompositeOp): @statement(dialect=dialect) class Identity(CompositeOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSize()}) - size: int = info.attribute() + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasNSitesTrait()}) + sites: int = info.attribute() result: ir.ResultValue = info.result(OpType) @statement class ConstantOp(PrimitiveOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sized(1)}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sites(1)}) result: ir.ResultValue = info.result(OpType) @statement class ConstantUnitary(ConstantOp): traits = frozenset( - {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sized(1)} + {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sites(1)} ) @@ -105,7 +105,7 @@ class PhaseOp(PrimitiveOp): $$ """ - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)}) theta: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) @@ -120,7 +120,7 @@ class ShiftOp(PrimitiveOp): $$ """ - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)}) theta: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) diff --git a/src/bloqade/squin/op/traits.py b/src/bloqade/squin/op/traits.py index 28ea65f0..f242e522 100644 --- a/src/bloqade/squin/op/traits.py +++ b/src/bloqade/squin/op/traits.py @@ -5,22 +5,22 @@ @dataclass(frozen=True) -class Sized(ir.StmtTrait): +class Sites(ir.StmtTrait): data: int @dataclass(frozen=True) -class HasSize(ir.StmtTrait): - """An operator with a `size` attribute.""" +class HasNSitesTrait(ir.StmtTrait): + """An operator with a `sites` attribute.""" - def get_size(self, stmt: ir.Statement): - attr = stmt.get_attr_or_prop("size") + def get_sites(self, stmt: ir.Statement): + attr = stmt.get_attr_or_prop("sites") if attr is None: - raise ValueError(f"Missing size attribute in {stmt}") + raise ValueError(f"Missing sites attribute in {stmt}") return cast(ir.PyAttr[int], attr).data - def set_size(self, stmt: ir.Statement, value: int): - stmt.attributes["size"] = ir.PyAttr(value) + def set_sites(self, stmt: ir.Statement, value: int): + stmt.attributes["sites"] = ir.PyAttr(value) return From c0a0593abc48b20ec8523f7857bd872fa238a08b Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 10 Apr 2025 10:21:19 -0400 Subject: [PATCH 13/18] add unit tests, implement name changes suggested by Roger --- src/bloqade/squin/analysis/nsites/__init__.py | 5 + src/bloqade/squin/analysis/nsites/analysis.py | 18 +- src/bloqade/squin/analysis/nsites/lattice.py | 22 +- src/bloqade/squin/op/stmts.py | 10 +- src/bloqade/squin/op/traits.py | 2 +- test/squin/analysis/test_nsites_analysis.py | 225 ++++++++++++++++++ 6 files changed, 256 insertions(+), 26 deletions(-) create mode 100644 test/squin/analysis/test_nsites_analysis.py diff --git a/src/bloqade/squin/analysis/nsites/__init__.py b/src/bloqade/squin/analysis/nsites/__init__.py index c6b10354..47f71d0c 100644 --- a/src/bloqade/squin/analysis/nsites/__init__.py +++ b/src/bloqade/squin/analysis/nsites/__init__.py @@ -1,3 +1,8 @@ # Need this for impl registration to work properly! from . import impls as impls +from .lattice import ( + NoSites as NoSites, + AnySites as AnySites, + HasNSites as HasNSites, +) from .analysis import NSitesAnalysis as NSitesAnalysis diff --git a/src/bloqade/squin/analysis/nsites/analysis.py b/src/bloqade/squin/analysis/nsites/analysis.py index 57f89566..bc0b850c 100644 --- a/src/bloqade/squin/analysis/nsites/analysis.py +++ b/src/bloqade/squin/analysis/nsites/analysis.py @@ -5,15 +5,15 @@ from kirin.analysis.forward import ForwardFrame from bloqade.squin.op.types import OpType -from bloqade.squin.op.traits import Sites, HasNSitesTrait +from bloqade.squin.op.traits import NSites, HasNSitesTrait -from .lattice import NSites, NoSites, HasNSites +from .lattice import Sites, NoSites, HasNSites -class NSitesAnalysis(Forward[NSites]): +class NSitesAnalysis(Forward[Sites]): keys = ["op.nsites"] - lattice = NSites + lattice = Sites # Take a page from const prop in Kirin, # I can get the data I want from the SizedTrait @@ -28,16 +28,16 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): has_n_sites_trait = stmt.get_trait(HasNSitesTrait) sites = has_n_sites_trait.get_sites(stmt) return (HasNSites(sites=sites),) - elif stmt.has_trait(Sites): - sites_trait = stmt.get_trait(Sites) + elif stmt.has_trait(NSites): + sites_trait = stmt.get_trait(NSites) return (HasNSites(sites=sites_trait.data),) else: return (NoSites(),) # For when no implementation is found for the statement def eval_stmt_fallback( - self, frame: ForwardFrame[NSites], stmt: ir.Statement - ) -> tuple[NSites, ...]: # some form of Shape will go back into the frame + self, frame: ForwardFrame[Sites], stmt: ir.Statement + ) -> tuple[Sites, ...]: # some form of Shape will go back into the frame return tuple( ( self.lattice.top() @@ -47,6 +47,6 @@ def eval_stmt_fallback( for result in stmt.results ) - def run_method(self, method: ir.Method, args: tuple[NSites, ...]): + def run_method(self, method: ir.Method, args: tuple[Sites, ...]): # NOTE: we do not support dynamic calls here, thus no need to propagate method object return self.run_callable(method.code, (self.lattice.bottom(),) + args) diff --git a/src/bloqade/squin/analysis/nsites/lattice.py b/src/bloqade/squin/analysis/nsites/lattice.py index 44e10c25..e0bf56e0 100644 --- a/src/bloqade/squin/analysis/nsites/lattice.py +++ b/src/bloqade/squin/analysis/nsites/lattice.py @@ -10,40 +10,40 @@ @dataclass -class NSites( - SimpleJoinMixin["NSites"], SimpleMeetMixin["NSites"], BoundedLattice["NSites"] +class Sites( + SimpleJoinMixin["Sites"], SimpleMeetMixin["Sites"], BoundedLattice["Sites"] ): @classmethod - def bottom(cls) -> "NSites": + def bottom(cls) -> "Sites": return NoSites() @classmethod - def top(cls) -> "NSites": + def top(cls) -> "Sites": return AnySites() @final @dataclass -class NoSites(NSites, metaclass=SingletonMeta): +class NoSites(Sites, metaclass=SingletonMeta): - def is_subseteq(self, other: NSites) -> bool: + def is_subseteq(self, other: Sites) -> bool: return True @final @dataclass -class AnySites(NSites, metaclass=SingletonMeta): +class AnySites(Sites, metaclass=SingletonMeta): - def is_subseteq(self, other: NSites) -> bool: - return isinstance(other, NSites) + def is_subseteq(self, other: Sites) -> bool: + return isinstance(other, Sites) @final @dataclass -class HasNSites(NSites): +class HasNSites(Sites): sites: int - def is_subseteq(self, other: NSites) -> bool: + def is_subseteq(self, other: Sites) -> bool: if isinstance(other, HasNSites): return self.sites == other.sites return False diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index 3615e30c..4f2e6784 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -2,7 +2,7 @@ from kirin.decl import info, statement from .types import OpType -from .traits import Sites, Unitary, MaybeUnitary, HasNSitesTrait +from .traits import NSites, Unitary, MaybeUnitary, HasNSitesTrait from .complex import Complex from ._dialect import dialect @@ -84,14 +84,14 @@ class Identity(CompositeOp): @statement class ConstantOp(PrimitiveOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sites(1)}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), NSites(1)}) result: ir.ResultValue = info.result(OpType) @statement class ConstantUnitary(ConstantOp): traits = frozenset( - {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sites(1)} + {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), NSites(1)} ) @@ -105,7 +105,7 @@ class PhaseOp(PrimitiveOp): $$ """ - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), NSites(1)}) theta: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) @@ -120,7 +120,7 @@ class ShiftOp(PrimitiveOp): $$ """ - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), NSites(1)}) theta: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) diff --git a/src/bloqade/squin/op/traits.py b/src/bloqade/squin/op/traits.py index f242e522..ff57ee71 100644 --- a/src/bloqade/squin/op/traits.py +++ b/src/bloqade/squin/op/traits.py @@ -5,7 +5,7 @@ @dataclass(frozen=True) -class Sites(ir.StmtTrait): +class NSites(ir.StmtTrait): data: int diff --git a/test/squin/analysis/test_nsites_analysis.py b/test/squin/analysis/test_nsites_analysis.py new file mode 100644 index 00000000..5fd083d3 --- /dev/null +++ b/test/squin/analysis/test_nsites_analysis.py @@ -0,0 +1,225 @@ +from kirin import ir, types +from kirin.passes import Fold +from kirin.dialects import py, func + +from bloqade import squin +from bloqade.squin.analysis import nsites + + +def as_int(value: int): + return py.constant.Constant(value=value) + + +def as_float(value: float): + return py.constant.Constant(value=value) + + +def gen_func_from_stmts(stmts): + + squin_with_py = squin.groups.wired.add(py) + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=squin.op.types.OpType), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=squin_with_py, + code=func_wrapper, + arg_names=[], + ) + + fold_pass = Fold(squin_with_py) + fold_pass(constructed_method) + + return constructed_method + + +def test_primitive_ops(): + pass + + +# Kron, Mult, Control, Rot, and Scale all have methods defined for handling them in impls, +# The following should ensure the code paths are properly exercised + + +def test_control(): + # Control doesn't have an impl but it is handled in the eval_stmt of the interpreter + # because it has a HasNSitesTrait future statements might have + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (controlled_h := squin.op.stmts.Control(op=h0.result, n_controls=1)), + (func.Return(controlled_h.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.HasNSites): + has_n_sites.append(nsites_type) + + assert len(has_n_sites) == 2 + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 2 + + +def test_kron(): + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (h1 := squin.op.stmts.H()), + (hh := squin.op.stmts.Kron(h0.result, h1.result)), + (func.Return(hh.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.HasNSites): + has_n_sites.append(nsites_type) + + assert len(has_n_sites) == 3 + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 1 + assert has_n_sites[2].sites == 3 + + +def test_mult_square_same_sites(): + # Ensure that two operators of the same size produce + # a valid operator as their result + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (h1 := squin.op.stmts.H()), + (h2 := squin.op.stmts.Mult(h0.result, h1.result)), + (func.Return(h2.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.HasNSites): + has_n_sites.append(nsites_type) + + # should be three HasNSites types + assert len(has_n_sites) == 3 + # the first 2 HasNSites will have 1 site but + # the Kron-produced operator should have 2 sites + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 1 + assert has_n_sites[2].sites == 2 + + +def test_mult_square_different_sites(): + # Ensure that two operators of different sizes produce + # NoSites as a type. Note that a better solution would be + # to implement a special error type in the type lattice + # but this would introduce some complexity later on + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (h1 := squin.op.stmts.H()), + # Kron to make nsites = 2 operator + (hh := squin.op.stmts.Kron(h0.result, h1.result)), + # apply Mult on HasNSites(2) and HasNSites(1) + (invalid_op := squin.op.stmts.Mult(hh.result, h1.result)), + (func.Return(invalid_op.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + nsites_types = list(nsites_frame.entries.values()) + + has_n_sites = [] + no_sites = [] + for nsite_type in nsites_types: + if isinstance(nsite_type, nsites.HasNSites): + has_n_sites.append(nsite_type) + elif isinstance(nsite_type, nsites.NoSites): + no_sites.append(nsite_type) + + assert len(has_n_sites) == 3 + # HasNSites(1) for Hadamards, 2 for Kron result + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 1 + assert has_n_sites[2].sites == 2 + # One from function itself, another from invalid mult + assert len(no_sites) == 2 + + +def test_rot(): + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (angle := as_float(0.2)), + (rot_h := squin.op.stmts.Rot(axis=h0.result, angle=angle.result)), + (func.Return(rot_h.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.HasNSites): + has_n_sites.append(nsites_type) + + assert len(has_n_sites) == 2 + # Rot should just propagate whatever Sites type is there + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 1 + + +def test_scale(): + + stmts: list[ir.Statement] = [ + (h0 := squin.op.stmts.H()), + (factor := as_float(0.2)), + (rot_h := squin.op.stmts.Scale(op=h0.result, factor=factor.result)), + (func.Return(rot_h.result)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.HasNSites): + has_n_sites.append(nsites_type) + + assert len(has_n_sites) == 2 + # Rot should just propagate whatever Sites type is there + assert has_n_sites[0].sites == 1 + assert has_n_sites[1].sites == 1 From 685e1c95948fe2abd38496f9a81757d0442809f2 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 10 Apr 2025 10:24:24 -0400 Subject: [PATCH 14/18] fix improper kron test --- test/squin/analysis/test_nsites_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/squin/analysis/test_nsites_analysis.py b/test/squin/analysis/test_nsites_analysis.py index 5fd083d3..a44ea585 100644 --- a/test/squin/analysis/test_nsites_analysis.py +++ b/test/squin/analysis/test_nsites_analysis.py @@ -98,7 +98,7 @@ def test_kron(): assert len(has_n_sites) == 3 assert has_n_sites[0].sites == 1 assert has_n_sites[1].sites == 1 - assert has_n_sites[2].sites == 3 + assert has_n_sites[2].sites == 2 def test_mult_square_same_sites(): From 22fc9e76d330158175ffcfc4453a3b58f805e097 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 10 Apr 2025 10:37:06 -0400 Subject: [PATCH 15/18] complete primitive op test --- test/squin/analysis/test_nsites_analysis.py | 33 ++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/test/squin/analysis/test_nsites_analysis.py b/test/squin/analysis/test_nsites_analysis.py index a44ea585..e9b1c95c 100644 --- a/test/squin/analysis/test_nsites_analysis.py +++ b/test/squin/analysis/test_nsites_analysis.py @@ -42,7 +42,38 @@ def gen_func_from_stmts(stmts): def test_primitive_ops(): - pass + # test a couple standard operators derived from PrimitiveOp + + stmts = [ + (n_qubits := as_int(1)), + (qreg := squin.qubit.New(n_qubits=n_qubits.result)), + (idx0 := as_int(0)), + (q := py.GetItem(qreg.result, idx0.result)), + # get wire + (w := squin.wire.Unwrap(q.result)), + # put wire through gates + (h := squin.op.stmts.H()), + (t := squin.op.stmts.T()), + (x := squin.op.stmts.X()), + (v0 := squin.wire.Apply(h.result, w.result)), + (v1 := squin.wire.Apply(t.result, v0.results[0])), + (v2 := squin.wire.Apply(x.result, v1.results[0])), + (func.Return(v2.results[0])), + ] + + constructed_method = gen_func_from_stmts(stmts) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + has_n_sites = [] + for nsites_type in nsites_frame.entries.values(): + if isinstance(nsites_type, nsites.HasNSites): + has_n_sites.append(nsites_type) + assert nsites_type.sites == 1 + + assert len(has_n_sites) == 3 # Kron, Mult, Control, Rot, and Scale all have methods defined for handling them in impls, From 39fdc84c7b7e01e1e0fec17494aaf8dca26e0749 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 10 Apr 2025 11:45:28 -0400 Subject: [PATCH 16/18] Rename trait and type in lattice to avoid collision --- src/bloqade/squin/analysis/nsites/__init__.py | 2 +- src/bloqade/squin/analysis/nsites/analysis.py | 12 ++++++------ src/bloqade/squin/analysis/nsites/impls.py | 14 +++++++------- src/bloqade/squin/analysis/nsites/lattice.py | 4 ++-- src/bloqade/squin/op/stmts.py | 4 ++-- src/bloqade/squin/op/traits.py | 2 +- test/squin/analysis/test_nsites_analysis.py | 14 +++++++------- 7 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/bloqade/squin/analysis/nsites/__init__.py b/src/bloqade/squin/analysis/nsites/__init__.py index 47f71d0c..da0a8e86 100644 --- a/src/bloqade/squin/analysis/nsites/__init__.py +++ b/src/bloqade/squin/analysis/nsites/__init__.py @@ -3,6 +3,6 @@ from .lattice import ( NoSites as NoSites, AnySites as AnySites, - HasNSites as HasNSites, + NumberSites as NumberSites, ) from .analysis import NSitesAnalysis as NSitesAnalysis diff --git a/src/bloqade/squin/analysis/nsites/analysis.py b/src/bloqade/squin/analysis/nsites/analysis.py index bc0b850c..82e2b159 100644 --- a/src/bloqade/squin/analysis/nsites/analysis.py +++ b/src/bloqade/squin/analysis/nsites/analysis.py @@ -5,9 +5,9 @@ from kirin.analysis.forward import ForwardFrame from bloqade.squin.op.types import OpType -from bloqade.squin.op.traits import NSites, HasNSitesTrait +from bloqade.squin.op.traits import NSites, HasSites -from .lattice import Sites, NoSites, HasNSites +from .lattice import Sites, NoSites, NumberSites class NSitesAnalysis(Forward[Sites]): @@ -24,13 +24,13 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): method = self.lookup_registry(frame, stmt) if method is not None: return method(self, frame, stmt) - elif stmt.has_trait(HasNSitesTrait): - has_n_sites_trait = stmt.get_trait(HasNSitesTrait) + elif stmt.has_trait(HasSites): + has_n_sites_trait = stmt.get_trait(HasSites) sites = has_n_sites_trait.get_sites(stmt) - return (HasNSites(sites=sites),) + return (NumberSites(sites=sites),) elif stmt.has_trait(NSites): sites_trait = stmt.get_trait(NSites) - return (HasNSites(sites=sites_trait.data),) + return (NumberSites(sites=sites_trait.data),) else: return (NoSites(),) diff --git a/src/bloqade/squin/analysis/nsites/impls.py b/src/bloqade/squin/analysis/nsites/impls.py index 5f9a9fe8..3a4f94f1 100644 --- a/src/bloqade/squin/analysis/nsites/impls.py +++ b/src/bloqade/squin/analysis/nsites/impls.py @@ -6,7 +6,7 @@ from .lattice import ( NoSites, - HasNSites, + NumberSites, ) from .analysis import NSitesAnalysis @@ -18,9 +18,9 @@ class SquinOp(interp.MethodTable): def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron): lhs = frame.get(stmt.lhs) rhs = frame.get(stmt.rhs) - if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites): + if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites): new_n_sites = lhs.sites + rhs.sites - return (HasNSites(sites=new_n_sites),) + return (NumberSites(sites=new_n_sites),) else: return (NoSites(),) @@ -29,7 +29,7 @@ def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult) lhs = frame.get(stmt.lhs) rhs = frame.get(stmt.rhs) - if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites): + if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites): lhs_sites = lhs.sites rhs_sites = rhs.sites # I originally considered throwing an exception here @@ -40,7 +40,7 @@ def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult) if lhs_sites != rhs_sites: return (NoSites(),) else: - return (HasNSites(sites=lhs_sites + rhs_sites),) + return (NumberSites(sites=lhs_sites + rhs_sites),) else: return (NoSites(),) @@ -50,11 +50,11 @@ def control( ): op_sites = frame.get(stmt.op) - if isinstance(op_sites, HasNSites): + if isinstance(op_sites, NumberSites): n_sites = op_sites.sites n_controls_attr = stmt.get_attr_or_prop("n_controls") n_controls = cast(ir.PyAttr[int], n_controls_attr).data - return (HasNSites(sites=n_sites + n_controls),) + return (NumberSites(sites=n_sites + n_controls),) else: return (NoSites(),) diff --git a/src/bloqade/squin/analysis/nsites/lattice.py b/src/bloqade/squin/analysis/nsites/lattice.py index e0bf56e0..cf11ef72 100644 --- a/src/bloqade/squin/analysis/nsites/lattice.py +++ b/src/bloqade/squin/analysis/nsites/lattice.py @@ -40,10 +40,10 @@ def is_subseteq(self, other: Sites) -> bool: @final @dataclass -class HasNSites(Sites): +class NumberSites(Sites): sites: int def is_subseteq(self, other: Sites) -> bool: - if isinstance(other, HasNSites): + if isinstance(other, NumberSites): return self.sites == other.sites return False diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index 4f2e6784..ecbadce3 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -2,7 +2,7 @@ from kirin.decl import info, statement from .types import OpType -from .traits import NSites, Unitary, MaybeUnitary, HasNSitesTrait +from .traits import NSites, Unitary, HasSites, MaybeUnitary from .complex import Complex from ._dialect import dialect @@ -77,7 +77,7 @@ class Rot(CompositeOp): @statement(dialect=dialect) class Identity(CompositeOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasNSitesTrait()}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSites()}) sites: int = info.attribute() result: ir.ResultValue = info.result(OpType) diff --git a/src/bloqade/squin/op/traits.py b/src/bloqade/squin/op/traits.py index ff57ee71..ede734ea 100644 --- a/src/bloqade/squin/op/traits.py +++ b/src/bloqade/squin/op/traits.py @@ -10,7 +10,7 @@ class NSites(ir.StmtTrait): @dataclass(frozen=True) -class HasNSitesTrait(ir.StmtTrait): +class HasSites(ir.StmtTrait): """An operator with a `sites` attribute.""" def get_sites(self, stmt: ir.Statement): diff --git a/test/squin/analysis/test_nsites_analysis.py b/test/squin/analysis/test_nsites_analysis.py index e9b1c95c..19caa3e0 100644 --- a/test/squin/analysis/test_nsites_analysis.py +++ b/test/squin/analysis/test_nsites_analysis.py @@ -69,7 +69,7 @@ def test_primitive_ops(): has_n_sites = [] for nsites_type in nsites_frame.entries.values(): - if isinstance(nsites_type, nsites.HasNSites): + if isinstance(nsites_type, nsites.NumberSites): has_n_sites.append(nsites_type) assert nsites_type.sites == 1 @@ -98,7 +98,7 @@ def test_control(): has_n_sites = [] for nsites_type in nsites_frame.entries.values(): - if isinstance(nsites_type, nsites.HasNSites): + if isinstance(nsites_type, nsites.NumberSites): has_n_sites.append(nsites_type) assert len(has_n_sites) == 2 @@ -123,7 +123,7 @@ def test_kron(): has_n_sites = [] for nsites_type in nsites_frame.entries.values(): - if isinstance(nsites_type, nsites.HasNSites): + if isinstance(nsites_type, nsites.NumberSites): has_n_sites.append(nsites_type) assert len(has_n_sites) == 3 @@ -151,7 +151,7 @@ def test_mult_square_same_sites(): has_n_sites = [] for nsites_type in nsites_frame.entries.values(): - if isinstance(nsites_type, nsites.HasNSites): + if isinstance(nsites_type, nsites.NumberSites): has_n_sites.append(nsites_type) # should be three HasNSites types @@ -190,7 +190,7 @@ def test_mult_square_different_sites(): has_n_sites = [] no_sites = [] for nsite_type in nsites_types: - if isinstance(nsite_type, nsites.HasNSites): + if isinstance(nsite_type, nsites.NumberSites): has_n_sites.append(nsite_type) elif isinstance(nsite_type, nsites.NoSites): no_sites.append(nsite_type) @@ -221,7 +221,7 @@ def test_rot(): has_n_sites = [] for nsites_type in nsites_frame.entries.values(): - if isinstance(nsites_type, nsites.HasNSites): + if isinstance(nsites_type, nsites.NumberSites): has_n_sites.append(nsites_type) assert len(has_n_sites) == 2 @@ -247,7 +247,7 @@ def test_scale(): has_n_sites = [] for nsites_type in nsites_frame.entries.values(): - if isinstance(nsites_type, nsites.HasNSites): + if isinstance(nsites_type, nsites.NumberSites): has_n_sites.append(nsites_type) assert len(has_n_sites) == 2 From 6f7c4b46fabf2a503af56c0c35e705c3e426b707 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 10 Apr 2025 11:51:55 -0400 Subject: [PATCH 17/18] Try to choose trait names to avoid collision + be clear --- src/bloqade/squin/analysis/nsites/analysis.py | 10 +++++----- src/bloqade/squin/op/stmts.py | 12 +++++++----- src/bloqade/squin/op/traits.py | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/bloqade/squin/analysis/nsites/analysis.py b/src/bloqade/squin/analysis/nsites/analysis.py index 82e2b159..2c4cd054 100644 --- a/src/bloqade/squin/analysis/nsites/analysis.py +++ b/src/bloqade/squin/analysis/nsites/analysis.py @@ -5,7 +5,7 @@ from kirin.analysis.forward import ForwardFrame from bloqade.squin.op.types import OpType -from bloqade.squin.op.traits import NSites, HasSites +from bloqade.squin.op.traits import HasSites, FixedSites from .lattice import Sites, NoSites, NumberSites @@ -25,11 +25,11 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement): if method is not None: return method(self, frame, stmt) elif stmt.has_trait(HasSites): - has_n_sites_trait = stmt.get_trait(HasSites) - sites = has_n_sites_trait.get_sites(stmt) + has_sites_trait = stmt.get_trait(HasSites) + sites = has_sites_trait.get_sites(stmt) return (NumberSites(sites=sites),) - elif stmt.has_trait(NSites): - sites_trait = stmt.get_trait(NSites) + elif stmt.has_trait(FixedSites): + sites_trait = stmt.get_trait(FixedSites) return (NumberSites(sites=sites_trait.data),) else: return (NoSites(),) diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index ecbadce3..dcffd93a 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -2,7 +2,7 @@ from kirin.decl import info, statement from .types import OpType -from .traits import NSites, Unitary, HasSites, MaybeUnitary +from .traits import Unitary, HasSites, FixedSites, MaybeUnitary from .complex import Complex from ._dialect import dialect @@ -84,14 +84,16 @@ class Identity(CompositeOp): @statement class ConstantOp(PrimitiveOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), NSites(1)}) + traits = frozenset( + {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), FixedSites(1)} + ) result: ir.ResultValue = info.result(OpType) @statement class ConstantUnitary(ConstantOp): traits = frozenset( - {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), NSites(1)} + {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), FixedSites(1)} ) @@ -105,7 +107,7 @@ class PhaseOp(PrimitiveOp): $$ """ - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), NSites(1)}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), FixedSites(1)}) theta: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) @@ -120,7 +122,7 @@ class ShiftOp(PrimitiveOp): $$ """ - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), NSites(1)}) + traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), FixedSites(1)}) theta: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) diff --git a/src/bloqade/squin/op/traits.py b/src/bloqade/squin/op/traits.py index ede734ea..506fdbf2 100644 --- a/src/bloqade/squin/op/traits.py +++ b/src/bloqade/squin/op/traits.py @@ -5,7 +5,7 @@ @dataclass(frozen=True) -class NSites(ir.StmtTrait): +class FixedSites(ir.StmtTrait): data: int From 2da07d316907e1e0895f5a20b9e17a93b5048fed Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 10 Apr 2025 14:12:17 -0400 Subject: [PATCH 18/18] Remove playground file --- squin_op_playground.py | 59 ------------------------------------------ 1 file changed, 59 deletions(-) delete mode 100644 squin_op_playground.py diff --git a/squin_op_playground.py b/squin_op_playground.py deleted file mode 100644 index 1235585a..00000000 --- a/squin_op_playground.py +++ /dev/null @@ -1,59 +0,0 @@ -from kirin import ir, types -from kirin.passes import Fold -from kirin.dialects import py, func - -from bloqade import squin -from bloqade.squin.analysis import nsites - - -def as_int(value: int): - return py.constant.Constant(value=value) - - -squin_with_qasm_core = squin.groups.wired.add(py) - -stmts: list[ir.Statement] = [ - (h0 := squin.op.stmts.H()), - (h1 := squin.op.stmts.H()), - (hh := squin.op.stmts.Kron(lhs=h1.result, rhs=h0.result)), - (chh := squin.op.stmts.Control(hh.result, n_controls=1)), - (factor := as_int(1)), - (schh := squin.op.stmts.Scale(chh.result, factor=factor.result)), - (func.Return(schh.result)), -] - -block = ir.Block(stmts) -block.args.append_from(types.MethodType[[], types.NoneType], "main_self") -func_wrapper = func.Function( - sym_name="main", - signature=func.Signature(inputs=(), output=squin.op.types.OpType), - body=ir.Region(blocks=block), -) - -constructed_method = ir.Method( - mod=None, - py_func=None, - sym_name="main", - dialects=squin_with_qasm_core, - code=func_wrapper, - arg_names=[], -) - -fold_pass = Fold(squin_with_qasm_core) -fold_pass(constructed_method) - -"""" -address_frame, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( - constructed_method, no_raise=False -) - - -constructed_method.print(analysis=address_frame.entries) -""" - -shape_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( - constructed_method, no_raise=False -) - - -constructed_method.print(analysis=shape_frame.entries)