Skip to content

Commit cfc7f69

Browse files
johnzl-777david-pl
andauthored
TerminalLogicalMeasurement implementation (#627)
Co-authored-by: David Plankensteiner <[email protected]>
1 parent d29a1c3 commit cfc7f69

File tree

10 files changed

+115
-18
lines changed

10 files changed

+115
-18
lines changed

src/bloqade/gemini/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .groups import logical as logical
1+
from .dialects import logical as logical
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
from .logical_validation.analysis import (
2-
GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis,
3-
)
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from . import impls as impls, analysis as analysis # NOTE: register methods
1+
from . import impls as impls # NOTE: register methods
2+
from .analysis import GeminiLogicalValidationAnalysis as GeminiLogicalValidationAnalysis
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import logical as logical
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from . import stmts as stmts
2+
from .groups import kernel as kernel
3+
from ._dialect import dialect as dialect
4+
from ._interface import terminal_measure as terminal_measure
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("logical")
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import TypeVar
2+
3+
from kirin import lowering
4+
from kirin.dialects import ilist
5+
6+
from bloqade.types import Qubit, MeasurementResult
7+
8+
from .stmts import TerminalLogicalMeasurement
9+
10+
Len = TypeVar("Len", bound=int)
11+
CodeN = TypeVar("CodeN", bound=int)
12+
13+
14+
@lowering.wraps(TerminalLogicalMeasurement)
15+
def terminal_measure(
16+
qubits: ilist.IList[Qubit, Len],
17+
) -> ilist.IList[ilist.IList[MeasurementResult, CodeN], Len]:
18+
"""Perform measurements on a list of logical qubits.
19+
20+
Measurements are returned as a nested list where each member list
21+
contains the individual measurement results for the constituent physical qubits per logical qubit.
22+
23+
Args:
24+
qubits (IList[Qubit, Len]): The list of logical qubits to measure.
25+
26+
Returns:
27+
IList[IList[MeasurementResult, CodeN], Len]: A nested list containing the measurement results,
28+
where each inner list corresponds to the measurements of the physical qubits that make up each logical qubit.
29+
"""
30+
...

src/bloqade/gemini/groups.py renamed to src/bloqade/gemini/dialects/logical/groups.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@
1010
from bloqade.squin import gate, qubit
1111
from bloqade.validation import KernelValidation
1212
from bloqade.rewrite.passes import AggressiveUnroll
13+
from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidationAnalysis
1314

14-
from .analysis import GeminiLogicalValidationAnalysis
15+
from ._dialect import dialect
1516

1617

17-
@ir.dialect_group(structural_no_opt.union([gate, py.constant, qubit, func, ilist]))
18-
def logical(self):
18+
@ir.dialect_group(
19+
structural_no_opt.union([gate, py.constant, qubit, func, ilist, dialect])
20+
)
21+
def kernel(self):
1922
"""Compile a function to a Gemini logical kernel."""
2023

2124
def run_pass(
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from kirin import ir, types, lowering
2+
from kirin.decl import info, statement
3+
from kirin.dialects import ilist
4+
5+
from bloqade.types import QubitType, MeasurementResultType
6+
7+
from ._dialect import dialect
8+
9+
Len = types.TypeVar("Len", bound=types.Int)
10+
CodeN = types.TypeVar("CodeN", bound=types.Int)
11+
12+
13+
@statement(dialect=dialect)
14+
class TerminalLogicalMeasurement(ir.Statement):
15+
"""Perform measurements on a list of logical qubits.
16+
17+
Measurements are returned as a nested list where each member list
18+
contains the individual measurement results for the constituent physical qubits per logical qubit.
19+
20+
Args:
21+
qubits (IList[QubitType, Len]): The list of logical qubits
22+
23+
Returns:
24+
IList[IList[MeasurementResultType, CodeN], Len]: A nested list containing the measurement results,
25+
where each inner list corresponds to the measurements of the physical qubits that make up each logical qubit.
26+
"""
27+
28+
traits = frozenset({lowering.FromPythonCall()})
29+
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
30+
result: ir.ResultValue = info.result(
31+
ilist.IListType[ilist.IListType[MeasurementResultType, CodeN], Len]
32+
)

test/gemini/test_logical_validation.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from bloqade import squin, gemini
55
from bloqade.types import Qubit
66
from bloqade.validation import KernelValidation
7-
from bloqade.gemini.analysis import GeminiLogicalValidationAnalysis
87
from bloqade.validation.kernel_validation import ValidationErrorGroup
8+
from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidationAnalysis
99

1010

1111
def test_if_stmt_invalid():
12-
@gemini.logical(verify=False)
12+
@gemini.logical.kernel(verify=False)
1313
def main():
1414
q = squin.qalloc(3)
1515

@@ -42,7 +42,7 @@ def main():
4242

4343
def test_for_loop():
4444

45-
@gemini.logical
45+
@gemini.logical.kernel
4646
def valid_loop():
4747
q = squin.qalloc(3)
4848

@@ -53,7 +53,7 @@ def valid_loop():
5353

5454
with pytest.raises(ir.ValidationError):
5555

56-
@gemini.logical
56+
@gemini.logical.kernel
5757
def invalid_loop(n: int):
5858
q = squin.qalloc(3)
5959

@@ -64,11 +64,11 @@ def invalid_loop(n: int):
6464

6565

6666
def test_func():
67-
@gemini.logical
67+
@gemini.logical.kernel
6868
def sub_kernel(q: Qubit):
6969
squin.x(q)
7070

71-
@gemini.logical
71+
@gemini.logical.kernel
7272
def main():
7373
q = squin.qalloc(3)
7474
sub_kernel(q[0])
@@ -77,14 +77,14 @@ def main():
7777

7878
with pytest.raises(ValidationErrorGroup):
7979

80-
@gemini.logical(inline=False)
80+
@gemini.logical.kernel(inline=False)
8181
def invalid():
8282
q = squin.qalloc(3)
8383
sub_kernel(q[0])
8484

8585

8686
def test_clifford_gates():
87-
@gemini.logical
87+
@gemini.logical.kernel
8888
def main():
8989
q = squin.qalloc(2)
9090
squin.u3(0.123, 0.253, 1.2, q[0])
@@ -94,7 +94,7 @@ def main():
9494

9595
with pytest.raises(ir.ValidationError):
9696

97-
@gemini.logical(no_raise=False)
97+
@gemini.logical.kernel(no_raise=False)
9898
def invalid():
9999
q = squin.qalloc(2)
100100

@@ -109,11 +109,37 @@ def invalid():
109109
invalid.print(analysis=frame.entries)
110110

111111

112+
def test_terminal_measurement():
113+
@gemini.logical.kernel(verify=False)
114+
def main():
115+
q = squin.qalloc(3)
116+
m = gemini.logical.terminal_measure(q)
117+
return m
118+
119+
main.print()
120+
121+
with pytest.raises(ir.ValidationError):
122+
123+
@gemini.logical.kernel(no_raise=False)
124+
def invalid():
125+
q = squin.qalloc(3)
126+
squin.x(q[0])
127+
m = gemini.logical.terminal_measure(q)
128+
another_m = gemini.logical.terminal_measure(q)
129+
return m, another_m
130+
131+
frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise(
132+
invalid
133+
)
134+
135+
invalid.print(analysis=frame.entries)
136+
137+
112138
def test_multiple_errors():
113139
did_error = False
114140
try:
115141

116-
@gemini.logical
142+
@gemini.logical.kernel
117143
def main(n: int):
118144
q = squin.qalloc(3)
119145
m = squin.qubit.measure(q[0])

0 commit comments

Comments
 (0)