Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/bloqade/analysis/address/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
NotQubit as NotQubit,
AddressReg as AddressReg,
AnyAddress as AnyAddress,
AddressWire as AddressWire,
AddressQubit as AddressQubit,
AddressTuple as AddressTuple,
)
Expand Down
4 changes: 3 additions & 1 deletion src/bloqade/analysis/address/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@


class AddressAnalysis(Forward[Address]):
"""This analysis pass can be used to track the global addresses of qubits."""
"""
This analysis pass can be used to track the global addresses of qubits and wires.
"""

keys = ["qubit.address"]
lattice = Address
Expand Down
81 changes: 80 additions & 1 deletion src/bloqade/analysis/address/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@
from kirin.analysis import ForwardFrame, const
from kirin.dialects import cf, py, scf, func, ilist

from .lattice import Address, NotQubit, AddressReg, AddressQubit, AddressTuple
from bloqade import squin

from .lattice import (
Address,
NotQubit,
AddressReg,
AddressWire,
AddressQubit,
AddressTuple,
)
from .analysis import AddressAnalysis


Expand Down Expand Up @@ -64,10 +73,16 @@
class PyIndexing(interp.MethodTable):
@interp.impl(py.GetItem)
def getitem(self, interp: AddressAnalysis, frame: interp.Frame, stmt: py.GetItem):
# Integer index into the thing being indexed
idx = interp.get_const_value(int, stmt.index)
# The object being indexed into
obj = frame.get(stmt.obj)
# The `data` attributes holds onto other Address types
# so we just extract that here
if isinstance(obj, AddressTuple):
return (obj.data[idx],)
# an AddressReg is guaranteed to just have some sequence
# of integers which is directly pluggable to AddressQubit
elif isinstance(obj, AddressReg):
return (AddressQubit(obj.data[idx]),)
else:
Expand Down Expand Up @@ -147,3 +162,67 @@
return # if terminate is Return, there is no result

return loop_vars


# Address lattice elements we can work with:
## NotQubit (bottom), AnyAddress (top)

## AddressTuple -> data: tuple[Address, ...]
### Recursive type, could contain itself or other variants
### This pops up in cases where you can have an IList/Tuple
### That contains elements that could be other Address types

## AddressReg -> data: Sequence[int]
### specific to creation of a register of qubits

## AddressQubit -> data: int
### Base qubit address type


@squin.wire.dialect.register(key="qubit.address")
class SquinWireMethodTable(interp.MethodTable):

@interp.impl(squin.wire.Unwrap)
def unwrap(
self,
interp_: AddressAnalysis,
frame: ForwardFrame[Address],
stmt: squin.wire.Unwrap,
):

origin_qubit = frame.get(stmt.qubit)

return (AddressWire(origin_qubit=origin_qubit),)

@interp.impl(squin.wire.Apply)
def apply(
self,
interp_: AddressAnalysis,
frame: ForwardFrame[Address],
stmt: squin.wire.Apply,
):

origin_qubits = tuple(
[frame.get(input_elem).origin_qubit for input_elem in stmt.inputs]
)
new_address_wires = tuple(
[AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits]
)
return new_address_wires


@squin.qubit.dialect.register(key="qubit.address")
class SquinQubitMethodTable(interp.MethodTable):

# This can be treated like a QRegNew impl
@interp.impl(squin.qubit.New)
def new(
self,
interp_: AddressAnalysis,
frame: ForwardFrame[Address],
stmt: squin.qubit.New,
):
n_qubits = interp_.get_const_value(int, stmt.n_qubits)
addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits))
interp_.next_address += n_qubits
return (addr,)

Check warning on line 228 in src/bloqade/analysis/address/impls.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/analysis/address/impls.py#L225-L228

Added lines #L225 - L228 were not covered by tests
11 changes: 11 additions & 0 deletions src/bloqade/analysis/address/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,14 @@
if isinstance(other, AddressQubit):
return self.data == other.data
return False


@final
@dataclass
class AddressWire(Address):
origin_qubit: AddressQubit

def is_subseteq(self, other: Address) -> bool:
if isinstance(other, AddressWire):
return self.origin_qubit == self.origin_qubit
return False

Check warning on line 85 in src/bloqade/analysis/address/lattice.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/analysis/address/lattice.py#L83-L85

Added lines #L83 - L85 were not covered by tests
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/dialects/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from bloqade.qasm2.parse import ast
from bloqade.qasm2.types import QRegType
from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame
from bloqade.analysis.schedule import DagScheduleAnalysis
from bloqade.squin.analysis.schedule import DagScheduleAnalysis

dialect = ir.Dialect("qasm2.glob")

Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/dialects/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bloqade.qasm2.parse import ast
from bloqade.qasm2.types import QubitType
from bloqade.qasm2.emit.gate import EmitQASM2Gate, EmitQASM2Frame
from bloqade.analysis.schedule import DagScheduleAnalysis
from bloqade.squin.analysis.schedule import DagScheduleAnalysis

dialect = ir.Dialect("qasm2.parallel")

Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/dialects/uop/schedule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from kirin import interp
from kirin.analysis import ForwardFrame

from bloqade.analysis.schedule import DagScheduleAnalysis
from bloqade.squin.analysis.schedule import DagScheduleAnalysis

from . import stmts
from ._dialect import dialect
Expand Down
3 changes: 2 additions & 1 deletion src/bloqade/qasm2/passes/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
)
from kirin.analysis import const

from bloqade.analysis import address, schedule
from bloqade.analysis import address
from bloqade.qasm2.rewrite import (
MergePolicyABC,
ParallelToUOpRule,
RaiseRegisterRule,
UOpToParallelRule,
SimpleOptimalMergePolicy,
)
from bloqade.squin.analysis import schedule


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/rewrite/uop_to_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from bloqade.analysis import address
from bloqade.qasm2.dialects import uop, core, parallel
from bloqade.analysis.schedule import StmtDag
from bloqade.squin.analysis.schedule import StmtDag


class MergePolicyABC(abc.ABC):
Expand Down
Empty file.
6 changes: 6 additions & 0 deletions src/bloqade/squin/op/complex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Stopgap Measure, squin dialect needs Complex type but
# this is only available in Kirin 0.15.x

from kirin.ir.attrs.types import PyClass

Complex = PyClass(complex)
3 changes: 2 additions & 1 deletion src/bloqade/squin/op/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .types import OpType
from .traits import Sized, HasSize, Unitary, MaybeUnitary
from .complex import Complex
from ._dialect import dialect


Expand Down Expand Up @@ -53,7 +54,7 @@ class Scale(CompositeOp):
traits = frozenset({ir.Pure(), ir.FromPythonCall(), MaybeUnitary()})
is_unitary: bool = info.attribute(default=False)
op: ir.SSAValue = info.argument(OpType)
factor: ir.SSAValue = info.argument(types.Complex)
factor: ir.SSAValue = info.argument(Complex)
result: ir.ResultValue = info.result(OpType)


Expand Down
34 changes: 28 additions & 6 deletions src/bloqade/squin/wire.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
dialect.
"""

from kirin import ir, types
from kirin import ir, types, interp
from kirin.decl import info, statement

from bloqade.types import QubitType

from .op.types import OpType

# from kirin.lowering import wraps

# from .op.types import Op, OpType

dialect = ir.Dialect("squin.wire")


Expand All @@ -35,26 +39,33 @@ class Wrap(ir.Statement):
qubit: ir.SSAValue = info.argument(QubitType)


# "Unwrap the quantum references to expose wires" -> From Quake Dialect documentation
# Unwrap(Qubit) -> Wire
@statement(dialect=dialect)
class Unwrap(ir.Statement):
traits = frozenset({ir.FromPythonCall(), ir.Pure()})
qubit: ir.SSAValue = info.argument(QubitType)
result: ir.ResultValue = info.result(WireType)


# In Quake, you put a wire in and get a wire out when you "apply" an operator
# In this case though we just need to indicate that an operator is applied to list[wires]
@statement(dialect=dialect)
class Apply(ir.Statement):
class Apply(ir.Statement): # apply(op, w1, w2, ...)
traits = frozenset({ir.FromPythonCall(), ir.Pure()})
operator: ir.SSAValue = info.argument(OpType)
inputs: tuple[ir.SSAValue] = info.argument(WireType)
inputs: tuple[ir.SSAValue, ...] = info.argument(WireType)

def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue):
result_types = tuple(WireType for _ in args)
super().__init__(
args=(operator,) + args,
result_types=result_types,
args_slice={"operator": 0, "inputs": slice(1, None)},
)
result_types=result_types, # result types of the Apply statement, should all be WireTypes
args_slice={
"operator": 0,
"inputs": slice(1, None),
}, # pretty printing + syntax sugar
) # custom lowering required for wrapper to work here


# NOTE: measurement cannot be pure because they will collapse the state
Expand All @@ -79,3 +90,14 @@ class MeasureAndReset(ir.Statement):
class Reset(ir.Statement):
traits = frozenset({ir.FromPythonCall(), WireTerminator()})
wire: ir.SSAValue = info.argument(WireType)


# Issue where constant propagation can't handle
# multiple return values from Apply properly
@dialect.register(key="constprop")
class ConstPropWire(interp.MethodTable):

@interp.impl(Apply)
def apply(self, interp, frame, stmt: Apply):

return frame.get_values(stmt.inputs)
Loading
Loading