diff --git a/src/bloqade/gemini/__init__.py b/src/bloqade/gemini/__init__.py index c03aeeb7..e5708280 100644 --- a/src/bloqade/gemini/__init__.py +++ b/src/bloqade/gemini/__init__.py @@ -1 +1 @@ -from .groups import logical as logical +from .dialects import logical as logical diff --git a/src/bloqade/gemini/analysis/__init__.py b/src/bloqade/gemini/analysis/__init__.py index 8c94d180..e69de29b 100644 --- a/src/bloqade/gemini/analysis/__init__.py +++ b/src/bloqade/gemini/analysis/__init__.py @@ -1,3 +0,0 @@ -from .logical_validation.analysis import ( - GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis, -) diff --git a/src/bloqade/gemini/analysis/logical_validation/__init__.py b/src/bloqade/gemini/analysis/logical_validation/__init__.py index 1b289d8c..dc0d77ef 100644 --- a/src/bloqade/gemini/analysis/logical_validation/__init__.py +++ b/src/bloqade/gemini/analysis/logical_validation/__init__.py @@ -1 +1,2 @@ -from . import impls as impls, analysis as analysis # NOTE: register methods +from . import impls as impls # NOTE: register methods +from .analysis import GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis diff --git a/src/bloqade/gemini/dialects/__init__.py b/src/bloqade/gemini/dialects/__init__.py new file mode 100644 index 00000000..96b315d7 --- /dev/null +++ b/src/bloqade/gemini/dialects/__init__.py @@ -0,0 +1 @@ +from . import logical as logical diff --git a/src/bloqade/gemini/dialects/logical/__init__.py b/src/bloqade/gemini/dialects/logical/__init__.py new file mode 100644 index 00000000..187ae361 --- /dev/null +++ b/src/bloqade/gemini/dialects/logical/__init__.py @@ -0,0 +1,4 @@ +from . import stmts as stmts +from .groups import kernel as kernel +from ._dialect import dialect as dialect +from ._interface import terminal_measure as terminal_measure diff --git a/src/bloqade/gemini/dialects/logical/_dialect.py b/src/bloqade/gemini/dialects/logical/_dialect.py new file mode 100644 index 00000000..6be7cae9 --- /dev/null +++ b/src/bloqade/gemini/dialects/logical/_dialect.py @@ -0,0 +1,3 @@ +from kirin import ir + +dialect = ir.Dialect("logical") diff --git a/src/bloqade/gemini/dialects/logical/_interface.py b/src/bloqade/gemini/dialects/logical/_interface.py new file mode 100644 index 00000000..f93b3b27 --- /dev/null +++ b/src/bloqade/gemini/dialects/logical/_interface.py @@ -0,0 +1,30 @@ +from typing import TypeVar + +from kirin import lowering +from kirin.dialects import ilist + +from bloqade.types import Qubit, MeasurementResult + +from .stmts import TerminalLogicalMeasurement + +Len = TypeVar("Len", bound=int) +CodeN = TypeVar("CodeN", bound=int) + + +@lowering.wraps(TerminalLogicalMeasurement) +def terminal_measure( + qubits: ilist.IList[Qubit, Len], +) -> ilist.IList[ilist.IList[MeasurementResult, CodeN], Len]: + """Perform measurements on a list of logical qubits. + + Measurements are returned as a nested list where each member list + contains the individual measurement results for the constituent physical qubits per logical qubit. + + Args: + qubits (IList[Qubit, Len]): The list of logical qubits to measure. + + Returns: + IList[IList[MeasurementResult, CodeN], Len]: A nested list containing the measurement results, + where each inner list corresponds to the measurements of the physical qubits that make up each logical qubit. + """ + ... diff --git a/src/bloqade/gemini/groups.py b/src/bloqade/gemini/dialects/logical/groups.py similarity index 89% rename from src/bloqade/gemini/groups.py rename to src/bloqade/gemini/dialects/logical/groups.py index 90441099..dea43e91 100644 --- a/src/bloqade/gemini/groups.py +++ b/src/bloqade/gemini/dialects/logical/groups.py @@ -10,12 +10,15 @@ from bloqade.squin import gate, qubit from bloqade.validation import KernelValidation from bloqade.rewrite.passes import AggressiveUnroll +from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidationAnalysis -from .analysis import GeminiLogicalValidationAnalysis +from ._dialect import dialect -@ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist])) -def logical(self): +@ir.dialect_group( + structural_no_opt.union([gate, py.constant, qubit, func, ilist, dialect]) +) +def kernel(self): """Compile a function to a Gemini logical kernel.""" def run_pass( diff --git a/src/bloqade/gemini/dialects/logical/stmts.py b/src/bloqade/gemini/dialects/logical/stmts.py new file mode 100644 index 00000000..e1a5684c --- /dev/null +++ b/src/bloqade/gemini/dialects/logical/stmts.py @@ -0,0 +1,32 @@ +from kirin import ir, types, lowering +from kirin.decl import info, statement +from kirin.dialects import ilist + +from bloqade.types import QubitType, MeasurementResultType + +from ._dialect import dialect + +Len = types.TypeVar("Len", bound=types.Int) +CodeN = types.TypeVar("CodeN", bound=types.Int) + + +@statement(dialect=dialect) +class TerminalLogicalMeasurement(ir.Statement): + """Perform measurements on a list of logical qubits. + + Measurements are returned as a nested list where each member list + contains the individual measurement results for the constituent physical qubits per logical qubit. + + Args: + qubits (IList[QubitType, Len]): The list of logical qubits + + Returns: + IList[IList[MeasurementResultType, CodeN], Len]: A nested list containing the measurement results, + where each inner list corresponds to the measurements of the physical qubits that make up each logical qubit. + """ + + traits = frozenset({lowering.FromPythonCall()}) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len]) + result: ir.ResultValue = info.result( + ilist.IListType[ilist.IListType[MeasurementResultType, CodeN], Len] + ) diff --git a/test/gemini/test_logical_validation.py b/test/gemini/test_logical_validation.py index 046b2ac4..34917778 100644 --- a/test/gemini/test_logical_validation.py +++ b/test/gemini/test_logical_validation.py @@ -4,12 +4,12 @@ from bloqade import squin, gemini from bloqade.types import Qubit from bloqade.validation import KernelValidation -from bloqade.gemini.analysis import GeminiLogicalValidationAnalysis from bloqade.validation.kernel_validation import ValidationErrorGroup +from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidationAnalysis def test_if_stmt_invalid(): - @gemini.logical(verify=False) + @gemini.logical.kernel(verify=False) def main(): q = squin.qalloc(3) @@ -42,7 +42,7 @@ def main(): def test_for_loop(): - @gemini.logical + @gemini.logical.kernel def valid_loop(): q = squin.qalloc(3) @@ -53,7 +53,7 @@ def valid_loop(): with pytest.raises(ir.ValidationError): - @gemini.logical + @gemini.logical.kernel def invalid_loop(n: int): q = squin.qalloc(3) @@ -64,11 +64,11 @@ def invalid_loop(n: int): def test_func(): - @gemini.logical + @gemini.logical.kernel def sub_kernel(q: Qubit): squin.x(q) - @gemini.logical + @gemini.logical.kernel def main(): q = squin.qalloc(3) sub_kernel(q[0]) @@ -77,14 +77,14 @@ def main(): with pytest.raises(ValidationErrorGroup): - @gemini.logical(inline=False) + @gemini.logical.kernel(inline=False) def invalid(): q = squin.qalloc(3) sub_kernel(q[0]) def test_clifford_gates(): - @gemini.logical + @gemini.logical.kernel def main(): q = squin.qalloc(2) squin.u3(0.123, 0.253, 1.2, q[0]) @@ -94,7 +94,7 @@ def main(): with pytest.raises(ir.ValidationError): - @gemini.logical(no_raise=False) + @gemini.logical.kernel(no_raise=False) def invalid(): q = squin.qalloc(2) @@ -109,11 +109,37 @@ def invalid(): invalid.print(analysis=frame.entries) +def test_terminal_measurement(): + @gemini.logical.kernel(verify=False) + def main(): + q = squin.qalloc(3) + m = gemini.logical.terminal_measure(q) + return m + + main.print() + + with pytest.raises(ir.ValidationError): + + @gemini.logical.kernel(no_raise=False) + def invalid(): + q = squin.qalloc(3) + squin.x(q[0]) + m = gemini.logical.terminal_measure(q) + another_m = gemini.logical.terminal_measure(q) + return m, another_m + + frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise( + invalid + ) + + invalid.print(analysis=frame.entries) + + def test_multiple_errors(): did_error = False try: - @gemini.logical + @gemini.logical.kernel def main(n: int): q = squin.qalloc(3) m = squin.qubit.measure(q[0])