Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bloqade/gemini/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .groups import logical as logical
from .dialects import logical as logical
3 changes: 0 additions & 3 deletions src/bloqade/gemini/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .logical_validation.analysis import (
GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis,
)
3 changes: 2 additions & 1 deletion src/bloqade/gemini/analysis/logical_validation/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import impls as impls, analysis as analysis # NOTE: register methods
from . import impls as impls # NOTE: register methods
from .analysis import GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis
1 change: 1 addition & 0 deletions src/bloqade/gemini/dialects/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import logical as logical
4 changes: 4 additions & 0 deletions src/bloqade/gemini/dialects/logical/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import stmts as stmts
from .groups import kernel as kernel
from ._dialect import dialect as dialect
from ._interface import terminal_measure as terminal_measure
3 changes: 3 additions & 0 deletions src/bloqade/gemini/dialects/logical/_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from kirin import ir

dialect = ir.Dialect("logical")
30 changes: 30 additions & 0 deletions src/bloqade/gemini/dialects/logical/_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import TypeVar

from kirin import lowering
from kirin.dialects import ilist

from bloqade.types import Qubit, MeasurementResult

from .stmts import TerminalLogicalMeasurement

Len = TypeVar("Len", bound=int)
CodeN = TypeVar("CodeN", bound=int)


@lowering.wraps(TerminalLogicalMeasurement)
def terminal_measure(
qubits: ilist.IList[Qubit, Len],
) -> ilist.IList[ilist.IList[MeasurementResult, CodeN], Len]:
"""Perform measurements on a list of logical qubits.

Measurements are returned as a nested list where each member list
contains the individual measurement results for the constituent physical qubits per logical qubit.

Args:
qubits (IList[Qubit, Len]): The list of logical qubits to measure.

Returns:
IList[IList[MeasurementResult, CodeN], Len]: A nested list containing the measurement results,
where each inner list corresponds to the measurements of the physical qubits that make up each logical qubit.
"""
...
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
from bloqade.squin import gate, qubit
from bloqade.validation import KernelValidation
from bloqade.rewrite.passes import AggressiveUnroll
from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidationAnalysis

from .analysis import GeminiLogicalValidationAnalysis
from ._dialect import dialect


@ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist]))
def logical(self):
@ir.dialect_group(
structural_no_opt.union([gate, py.constant, qubit, func, ilist, dialect])
)
def kernel(self):
"""Compile a function to a Gemini logical kernel."""

def run_pass(
Expand Down
32 changes: 32 additions & 0 deletions src/bloqade/gemini/dialects/logical/stmts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from kirin import ir, types, lowering
from kirin.decl import info, statement
from kirin.dialects import ilist

from bloqade.types import QubitType, MeasurementResultType

from ._dialect import dialect

Len = types.TypeVar("Len", bound=types.Int)
CodeN = types.TypeVar("CodeN", bound=types.Int)


@statement(dialect=dialect)
class TerminalLogicalMeasurement(ir.Statement):
"""Perform measurements on a list of logical qubits.

Measurements are returned as a nested list where each member list
contains the individual measurement results for the constituent physical qubits per logical qubit.

Args:
qubits (IList[QubitType, Len]): The list of logical qubits

Returns:
IList[IList[MeasurementResultType, CodeN], Len]: A nested list containing the measurement results,
where each inner list corresponds to the measurements of the physical qubits that make up each logical qubit.
"""

traits = frozenset({lowering.FromPythonCall()})
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
result: ir.ResultValue = info.result(
ilist.IListType[ilist.IListType[MeasurementResultType, CodeN], Len]
)
46 changes: 36 additions & 10 deletions test/gemini/test_logical_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from bloqade import squin, gemini
from bloqade.types import Qubit
from bloqade.validation import KernelValidation
from bloqade.gemini.analysis import GeminiLogicalValidationAnalysis
from bloqade.validation.kernel_validation import ValidationErrorGroup
from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidationAnalysis


def test_if_stmt_invalid():
@gemini.logical(verify=False)
@gemini.logical.kernel(verify=False)
def main():
q = squin.qalloc(3)

Expand Down Expand Up @@ -42,7 +42,7 @@ def main():

def test_for_loop():

@gemini.logical
@gemini.logical.kernel
def valid_loop():
q = squin.qalloc(3)

Expand All @@ -53,7 +53,7 @@ def valid_loop():

with pytest.raises(ir.ValidationError):

@gemini.logical
@gemini.logical.kernel
def invalid_loop(n: int):
q = squin.qalloc(3)

Expand All @@ -64,11 +64,11 @@ def invalid_loop(n: int):


def test_func():
@gemini.logical
@gemini.logical.kernel
def sub_kernel(q: Qubit):
squin.x(q)

@gemini.logical
@gemini.logical.kernel
def main():
q = squin.qalloc(3)
sub_kernel(q[0])
Expand All @@ -77,14 +77,14 @@ def main():

with pytest.raises(ValidationErrorGroup):

@gemini.logical(inline=False)
@gemini.logical.kernel(inline=False)
def invalid():
q = squin.qalloc(3)
sub_kernel(q[0])


def test_clifford_gates():
@gemini.logical
@gemini.logical.kernel
def main():
q = squin.qalloc(2)
squin.u3(0.123, 0.253, 1.2, q[0])
Expand All @@ -94,7 +94,7 @@ def main():

with pytest.raises(ir.ValidationError):

@gemini.logical(no_raise=False)
@gemini.logical.kernel(no_raise=False)
def invalid():
q = squin.qalloc(2)

Expand All @@ -109,11 +109,37 @@ def invalid():
invalid.print(analysis=frame.entries)


def test_terminal_measurement():
@gemini.logical.kernel(verify=False)
def main():
q = squin.qalloc(3)
m = gemini.logical.terminal_measure(q)
return m

main.print()

with pytest.raises(ir.ValidationError):

@gemini.logical.kernel(no_raise=False)
def invalid():
q = squin.qalloc(3)
squin.x(q[0])
m = gemini.logical.terminal_measure(q)
another_m = gemini.logical.terminal_measure(q)
return m, another_m

frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise(
invalid
)

invalid.print(analysis=frame.entries)


def test_multiple_errors():
did_error = False
try:

@gemini.logical
@gemini.logical.kernel
def main(n: int):
q = squin.qalloc(3)
m = squin.qubit.measure(q[0])
Expand Down
Loading