Skip to content

Commit b31a011

Browse files
authored
Refactor Logical moves on Gemini stdlib. (#85)
* adding new logical moves * adding new logical moves * adding tests of individual logical moves * adding docs + fixing tests * Renaming variable * putting more move logic into tweezer kernel * refactor tweezer implementations * Fixing name * Adding tests * adding move kernels that wrap the tweezer kernels
1 parent 6c4f71c commit b31a011

File tree

2 files changed

+359
-122
lines changed

2 files changed

+359
-122
lines changed

src/bloqade/shuttle/stdlib/layouts/gemini/logical.py

Lines changed: 203 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1-
from typing import TypeVar
1+
from typing import Any
22

33
from bloqade.geometry.dialects import grid
44
from kirin.dialects import ilist
55

6-
from bloqade.shuttle import action, gate, schedule, spec, tweezer
6+
from bloqade.shuttle import action, schedule, spec, tweezer
77
from bloqade.shuttle.prelude import move
88

99
from ..asserts import assert_sorted
1010
from .base_spec import get_base_spec
1111

1212

1313
def get_spec():
14+
"""Get the architecture specification for the Gemini logical qubit layout.
15+
16+
Returns:
17+
ArchSpec: The architecture specification with Gemini logical qubit layout.
18+
19+
"""
1420
arch_spec = get_base_spec()
1521

1622
gate_zone = arch_spec.layout.static_traps["gate_zone"]
@@ -39,11 +45,14 @@ def get_spec():
3945
ML1_block = M1_block[::2, :]
4046
MR1_block = M1_block[1::2, :]
4147

42-
GL0_block = left_traps[2 : 2 + 7 :, :]
43-
GR0_block = right_traps[2 : 2 + 7 :, :]
48+
GL_blocks = left_traps[2 : 2 + 2 * 7, :]
49+
GR_blocks = right_traps[2 : 2 + 2 * 7, :]
50+
51+
GL0_block = GL_blocks[:7, :]
52+
GR0_block = GR_blocks[:7, :]
4453

45-
GL1_block = left_traps[2 + 7 : 2 + 2 * 7 :, :]
46-
GR1_block = right_traps[2 + 7 : 2 + 2 * 7 :, :]
54+
GL1_block = GL_blocks[7 : 2 * 7, :]
55+
GR1_block = GR_blocks[7 : 2 * 7, :]
4756

4857
AOM0_block = aom_sites[2 : 2 + 7, :]
4958
AOM1_block = aom_sites[2 + 7 : 2 + 7 + 7, :]
@@ -53,6 +62,8 @@ def get_spec():
5362
"right_gate_zone_sites": right_traps,
5463
"top_reservoir_sites": top_reservoir,
5564
"bottom_reservoir_sites": bottom_reservoir,
65+
"GL_blocks": GL_blocks,
66+
"GR_blocks": GR_blocks,
5667
"GL0_block": GL0_block,
5768
"GL1_block": GL1_block,
5869
"GR0_block": GR0_block,
@@ -79,7 +90,7 @@ def get_spec():
7990
("GL0_block", "GL1_block", "GR0_block", "GR1_block")
8091
)
8192

82-
logical_rows, _ = SL0_block.shape
93+
_, logical_rows = GL0_block.shape
8394
logical_cols = 2
8495
code_size = 7
8596

@@ -94,66 +105,211 @@ def get_spec():
94105
return arch_spec
95106

96107

97-
N = TypeVar("N")
108+
@tweezer
109+
def move_by_shift(
110+
start_pos: grid.Grid[Any, Any],
111+
shifts: ilist.IList[tuple[float, float], Any],
112+
active_x: ilist.IList[int, Any],
113+
active_y: ilist.IList[int, Any],
114+
):
115+
"""Moves the specified atoms by applying a series of shifts.
116+
117+
Args:
118+
start_pos (grid.Grid[Any, Any]): The starting position of the atoms.
119+
shifts (ilist.IList[tuple[float, float], Any]): The list of shifts to apply.
120+
active_x (ilist.IList[int, Any]): The list of active x indices of start_pos.
121+
active_y (ilist.IList[int, Any]): The list of active y indices of start_pos.
122+
"""
123+
action.set_loc(start_pos)
124+
action.turn_on(active_x, active_y)
125+
126+
current_pos = start_pos
127+
for shift in shifts:
128+
current_pos = grid.shift(current_pos, shift[0], shift[1])
129+
action.move(current_pos)
130+
131+
action.turn_off(active_x, active_y)
98132

99133

100134
@tweezer
101-
def ltor_block_aom_move(
102-
left_subblocks: ilist.IList[int, N],
103-
right_subblocks: ilist.IList[int, N],
135+
def get_block(
136+
block_id: str,
137+
col_index: int,
138+
row_indices: ilist.IList[int, Any],
104139
):
105-
assert_sorted(left_subblocks)
106-
assert_sorted(right_subblocks)
140+
"""Returns the zone corresponding to the specified block and column.
141+
142+
Args:
143+
block_id (str): The block identifier, either "GL" or "GR".
144+
col_index (int): The logical column index.
145+
row_indices (IList[int, Any]): The list of logical row indices.
146+
147+
Returns:
148+
Grid: The grid corresponding to the specified block and column.
149+
150+
"""
151+
assert block_id in ("GL", "GR"), "block_id must be either 'GL' or 'GR'"
107152

108-
assert len(left_subblocks) == len(
109-
right_subblocks
110-
), "Left and right subblocks must have the same length."
153+
block = None
154+
if block_id == "GL":
155+
block = spec.get_static_trap(zone_id="GL_blocks")
156+
elif block_id == "GR":
157+
block = spec.get_static_trap(zone_id="GR_blocks")
111158

112-
left_blocks = spec.get_static_trap(zone_id="GL0_block")
113-
right_blocks = spec.get_special_grid(grid_id="AOM1_block")
159+
code_size = spec.get_int_constant(constant_id="code_size")
160+
return block[col_index * code_size : (col_index + 1) * code_size, row_indices]
114161

115-
left_block = left_blocks[:, left_subblocks]
116-
right_block = right_blocks[:, right_subblocks]
162+
163+
@tweezer
164+
def calc_vertical_shifts(
165+
offset: int,
166+
):
167+
"""Generates a list of shifts to move atoms vertically by the specified offset.
168+
169+
Args:
170+
offset (int): The offset to apply to the row indices, must be non-negative.
171+
"""
117172
row_separation = spec.get_float_constant(constant_id="row_separation")
118173
col_separation = spec.get_float_constant(constant_id="col_separation")
119174
gate_spacing = spec.get_float_constant(constant_id="gate_spacing")
120175

121-
# AOM sites are already shifted by the gate spacing, so to shift to the center between the
122-
# two blocks, we need to shift the AOM sites by half the col separation minus the gate
123-
# spacing.
124-
shift_from_aom = col_separation / 2.0 - gate_spacing
125-
third_pos = grid.shift(right_block, -shift_from_aom, 0.0)
126-
first_pos = grid.shift(left_block, 0.0, row_separation / 2.0)
127-
second_pos = grid.from_positions(grid.get_xpos(third_pos), grid.get_ypos(first_pos))
176+
sign = 1
177+
if offset < 0:
178+
sign = -1
179+
offset = -offset
180+
181+
shifts = ilist.IList([])
182+
if offset > 1:
183+
shifts = ilist.IList(
184+
[
185+
(0.0, row_separation * 0.5),
186+
(gate_spacing + col_separation * 0.5, 0.0),
187+
(0.0, row_separation * (offset - 0.5)),
188+
(-col_separation * 0.5, 0.0),
189+
]
190+
)
191+
elif offset == 1:
192+
shifts = ilist.IList(
193+
[
194+
(0.0, row_separation * 0.5),
195+
(gate_spacing, 0.0),
196+
(0.0, row_separation * 0.5),
197+
]
198+
)
199+
else:
200+
shifts = ilist.IList([(gate_spacing, 0.0)])
201+
202+
def multiple_sign(idx: int):
203+
return (shifts[idx][0], sign * shifts[idx][1])
204+
205+
return ilist.map(multiple_sign, ilist.range(len(shifts)))
206+
207+
208+
@tweezer
209+
def vertical_shift_impl(
210+
offset: int,
211+
src_col: int,
212+
src_rows: ilist.IList[int, Any],
213+
):
214+
"""Moves the specified rows within the given block.
215+
216+
Args:
217+
offset (int): The offset to apply to the row indices, must be non-negative.
218+
src_col (int): The source column index.
219+
src_rows (ilist.IList[int, Any]): The list of source row indices.
220+
"""
221+
222+
def check_row(row: int):
223+
num_rows = spec.get_int_constant(constant_id="logical_rows")
224+
assert (
225+
row + offset < num_rows
226+
), "row index + offset must be less than `logical_rows`"
227+
assert row + offset >= 0, "row index + offset must be non-negative"
128228

129-
action.set_loc(left_block)
130-
action.turn_on(action.ALL, action.ALL)
131-
action.move(first_pos)
132-
action.move(second_pos)
133-
action.move(third_pos)
134-
action.move(right_block)
229+
ilist.for_each(check_row, src_rows)
230+
231+
assert_sorted(src_rows)
232+
233+
row_start = 0
234+
row_end = spec.get_int_constant(constant_id="logical_rows")
235+
236+
if offset > 0:
237+
row_end = row_end - offset
238+
else:
239+
row_start = row_start - offset
240+
241+
start_pos = get_block("GL", src_col, ilist.range(row_start, row_end))
242+
shape = grid.shape(start_pos)
243+
all_cols = ilist.range(shape[0])
244+
245+
shifts = calc_vertical_shifts(offset)
246+
247+
move_by_shift(start_pos, shifts, all_cols, src_rows)
135248

136249

137250
@move
138-
def get_device_fn(
139-
left_subblocks: ilist.IList[int, N],
140-
right_subblocks: ilist.IList[int, N],
251+
def vertical_shift(
252+
offset: int,
253+
src_col: int,
254+
src_rows: ilist.IList[int, Any],
141255
):
256+
"""Moves the specified rows within the given block.
257+
258+
Args:
259+
offset (int): The offset to apply to the row indices, must be non-negative.
260+
src_col (int): The source column index.
261+
src_rows (ilist.IList[int, Any]): The list of source row indices.
262+
"""
263+
142264
x_tones = ilist.range(spec.get_int_constant(constant_id="code_size"))
143-
y_tones = ilist.range(len(left_subblocks))
144-
return schedule.device_fn(ltor_block_aom_move, x_tones, y_tones)
265+
y_tones = ilist.range(len(src_rows))
266+
267+
device_fn = schedule.device_fn(vertical_shift_impl, x_tones, y_tones)
268+
device_fn(offset, src_col, src_rows)
269+
270+
271+
@tweezer
272+
def gr_zero_to_one_impl(
273+
src_rows: ilist.IList[int, Any],
274+
):
275+
"""Moves the specified columns within the given block.
276+
277+
Args:
278+
src_rows (ilist.IList[int, Any]): The rows to apply the transformation to.
279+
"""
280+
logical_rows = spec.get_int_constant(constant_id="logical_rows")
281+
row_separation = spec.get_float_constant(constant_id="row_separation")
282+
col_separation = spec.get_float_constant(constant_id="col_separation")
283+
shift = col_separation * spec.get_float_constant(constant_id="code_size")
284+
285+
shifts = ilist.IList(
286+
[
287+
(0.0, row_separation * 0.5),
288+
(shift, 0.0),
289+
(0.0, -row_separation * 0.5),
290+
]
291+
)
292+
293+
all_rows = ilist.range(logical_rows)
294+
start_pos = get_block("GR", 0, all_rows)
295+
296+
shape = grid.shape(start_pos)
297+
all_cols = ilist.range(shape[0])
298+
299+
move_by_shift(start_pos, shifts, all_cols, src_rows)
145300

146301

147302
@move
148-
def entangle(
149-
left_subblocks: ilist.IList[int, N],
150-
right_subblocks: ilist.IList[int, N],
303+
def gr_zero_to_one(
304+
src_rows: ilist.IList[int, Any],
151305
):
306+
"""Moves the specified columns within the given block.
152307
153-
device_func = get_device_fn(left_subblocks, right_subblocks)
154-
rev_func = schedule.reverse(device_func)
308+
Args:
309+
src_rows (ilist.IList[int, Any]): The rows to apply the transformation to.
310+
"""
311+
x_tones = ilist.range(spec.get_int_constant(constant_id="code_size"))
312+
y_tones = ilist.range(len(src_rows))
155313

156-
if len(left_subblocks) > 0:
157-
device_func(left_subblocks, right_subblocks)
158-
gate.top_hat_cz(spec.get_static_trap(zone_id="gate_zone"))
159-
rev_func(left_subblocks, right_subblocks)
314+
device_fn = schedule.device_fn(gr_zero_to_one_impl, x_tones, y_tones)
315+
device_fn(src_rows)

0 commit comments

Comments
 (0)