Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bloqade/gemini/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .groups import logical as logical
from . import logical as logical
4 changes: 4 additions & 0 deletions src/bloqade/gemini/logical/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import stmts as stmts
from .groups import kernel as kernel
from ._dialect import dialect as dialect
from ._interface import terminal_measure as terminal_measure
3 changes: 3 additions & 0 deletions src/bloqade/gemini/logical/_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from kirin import ir

dialect = ir.Dialect("gemini.logical")
30 changes: 30 additions & 0 deletions src/bloqade/gemini/logical/_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import TypeVar

from kirin import lowering
from kirin.dialects import ilist

from bloqade.types import Qubit, MeasurementResult

from .stmts import TerminalLogicalMeasurement

Len = TypeVar("Len", bound=int)
CodeN = TypeVar("CodeN", bound=int)


@lowering.wraps(TerminalLogicalMeasurement)
def terminal_measure(
qubits: ilist.IList[Qubit, Len],
) -> ilist.IList[ilist.IList[MeasurementResult, CodeN], Len]:
"""Perform measurements on a list of logical qubits.

Measurements are returned as a nested list where each member list
contains the individual measurement results for the constituent physical qubits per logical qubit.

Args:
qubits (IList[Qubit, Len]): The list of logical qubits to measure.

Returns:
IList[IList[MeasurementResult, CodeN], Len]: A nested list containing the measurement results,
where each inner list corresponds to the measurements of the physical qubits that make up each logical qubit.
"""
...
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from bloqade.validation import KernelValidation
from bloqade.rewrite.passes import AggressiveUnroll

from .analysis import GeminiLogicalValidationAnalysis
from ..analysis import GeminiLogicalValidationAnalysis


@ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist]))
def logical(self):
def kernel(self):
"""Compile a function to a Gemini logical kernel."""

def run_pass(
Expand Down
19 changes: 19 additions & 0 deletions src/bloqade/gemini/logical/stmts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from kirin import ir, types, lowering
from kirin.decl import info, statement
from kirin.dialects import ilist

from bloqade.types import QubitType, MeasurementResultType

from ._dialect import dialect

Len = types.TypeVar("Len", bound=types.Int)
CodeN = types.TypeVar("CodeN", bound=types.Int)


@statement(dialect=dialect)
class TerminalLogicalMeasurement(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
result: ir.ResultValue = info.result(
ilist.IListType[ilist.IListType[MeasurementResultType, CodeN], Len]
)
18 changes: 9 additions & 9 deletions test/gemini/test_logical_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def test_if_stmt_invalid():
@gemini.logical(verify=False)
@gemini.logical.kernel(verify=False)
def main():
q = squin.qalloc(3)

Expand Down Expand Up @@ -42,7 +42,7 @@ def main():

def test_for_loop():

@gemini.logical
@gemini.logical.kernel
def valid_loop():
q = squin.qalloc(3)

Expand All @@ -53,7 +53,7 @@ def valid_loop():

with pytest.raises(ir.ValidationError):

@gemini.logical
@gemini.logical.kernel
def invalid_loop(n: int):
q = squin.qalloc(3)

Expand All @@ -64,11 +64,11 @@ def invalid_loop(n: int):


def test_func():
@gemini.logical
@gemini.logical.kernel
def sub_kernel(q: Qubit):
squin.x(q)

@gemini.logical
@gemini.logical.kernel
def main():
q = squin.qalloc(3)
sub_kernel(q[0])
Expand All @@ -77,14 +77,14 @@ def main():

with pytest.raises(ValidationErrorGroup):

@gemini.logical(inline=False)
@gemini.logical.kernel(inline=False)
def invalid():
q = squin.qalloc(3)
sub_kernel(q[0])


def test_clifford_gates():
@gemini.logical
@gemini.logical.kernel
def main():
q = squin.qalloc(2)
squin.u3(0.123, 0.253, 1.2, q[0])
Expand All @@ -94,7 +94,7 @@ def main():

with pytest.raises(ir.ValidationError):

@gemini.logical(no_raise=False)
@gemini.logical.kernel(no_raise=False)
def invalid():
q = squin.qalloc(2)

Expand All @@ -113,7 +113,7 @@ def test_multiple_errors():
did_error = False
try:

@gemini.logical
@gemini.logical.kernel
def main(n: int):
q = squin.qalloc(3)
m = squin.qubit.measure(q[0])
Expand Down
Loading