Skip to content

Commit 77dc4b0

Browse files
authored
implement vacant grid (#31)
* implement vacant grid * adding positions API * moving position of property * adding wrappers for filled * refactor is_equal * adding backup * fixing names * fixing comments from cduck * Fixing bug in interpreter and adding test for repeat * adding tests + fixing runtime
1 parent afd10d3 commit 77dc4b0

File tree

10 files changed

+416
-3
lines changed

10 files changed

+416
-3
lines changed

src/bloqade/shuttle/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .dialects.action import _interface as action
22
from .dialects.atom import _interface as atom
3+
from .dialects.filled import _interface as filled
34
from .dialects.gate import _interface as gate
45
from .dialects.init import _interface as init
56
from .dialects.measure import _interface as measure
@@ -15,4 +16,5 @@
1516
"measure",
1617
"schedule",
1718
"spec",
19+
"filled",
1820
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from ._dialect import dialect as dialect
2+
from ._interface import (
3+
fill as fill,
4+
get_parent as get_parent,
5+
vacate as vacate,
6+
)
7+
from .concrete import FilledGridMethods as FilledGridMethods
8+
from .stmts import Fill as Fill, GetParent as GetParent, Vacate as Vacate
9+
from .types import FilledGrid as FilledGrid, FilledGridType as FilledGridType
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("shuttle.filled")
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import Any, TypeVar
2+
3+
from bloqade.geometry.dialects import grid
4+
from kirin.dialects import ilist
5+
from kirin.lowering import wraps as _wraps
6+
7+
from .stmts import Fill, GetParent, Vacate
8+
from .types import FilledGrid
9+
10+
Nx = TypeVar("Nx")
11+
Ny = TypeVar("Ny")
12+
13+
14+
@_wraps(Vacate)
15+
def vacate(
16+
zone: grid.Grid[Nx, Ny],
17+
vacancies: ilist.IList[tuple[int, int], Any],
18+
) -> FilledGrid[Nx, Ny]: ...
19+
20+
21+
@_wraps(Fill)
22+
def fill(
23+
zone: grid.Grid[Nx, Ny],
24+
filled: ilist.IList[tuple[int, int], Any],
25+
) -> FilledGrid[Nx, Ny]: ...
26+
27+
28+
@_wraps(GetParent)
29+
def get_parent(filled_grid: FilledGrid[Nx, Ny]) -> grid.Grid[Nx, Ny]: ...
30+
31+
32+
@_wraps(grid.Shift)
33+
def shift(
34+
grid: FilledGrid[Nx, Ny], x_shift: float, y_shift: float
35+
) -> FilledGrid[Nx, Ny]: ...
36+
37+
38+
@_wraps(grid.Scale)
39+
def scale(
40+
grid: FilledGrid[Nx, Ny], x_scale: float, y_scale: float
41+
) -> FilledGrid[Nx, Ny]: ...
42+
43+
44+
@_wraps(grid.Repeat)
45+
def repeat(
46+
grid: FilledGrid[Any, Any],
47+
x_times: int,
48+
y_times: int,
49+
y_spacing: float,
50+
x_spacing: float,
51+
) -> FilledGrid[Any, Any]: ...
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from typing import Any
2+
3+
from bloqade.geometry.dialects.grid.types import Grid
4+
from kirin.dialects import ilist
5+
from kirin.interp import (
6+
Frame,
7+
Interpreter,
8+
MethodTable,
9+
impl,
10+
)
11+
12+
from . import stmts
13+
from ._dialect import dialect
14+
from .types import FilledGrid
15+
16+
17+
@dialect.register
18+
class FilledGridMethods(MethodTable):
19+
20+
@impl(stmts.Vacate)
21+
def vacate(self, interp: Interpreter, frame: Frame, stmt: stmts.Vacate):
22+
zone = frame.get_casted(stmt.zone, Grid)
23+
vacancies = frame.get_casted(stmt.vacancies, ilist.IList[tuple[int, int], Any])
24+
return (FilledGrid.vacate(zone, vacancies),)
25+
26+
@impl(stmts.Fill)
27+
def fill(self, interp: Interpreter, frame: Frame, stmt: stmts.Fill):
28+
zone = frame.get_casted(stmt.zone, Grid)
29+
filled = frame.get_casted(stmt.filled, ilist.IList[tuple[int, int], Any])
30+
return (FilledGrid.fill(zone, filled),)
31+
32+
@impl(stmts.GetParent)
33+
def get_parent(self, interp: Interpreter, frame: Frame, stmt: stmts.GetParent):
34+
filled_grid = frame.get_casted(stmt.filled_grid, FilledGrid)
35+
return (filled_grid.parent,)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from bloqade.geometry.dialects import grid
2+
from kirin import decl, ir, lowering, types
3+
from kirin.decl import info
4+
from kirin.dialects import ilist
5+
6+
from ._dialect import dialect
7+
from .types import FilledGridType
8+
9+
NumVacant = types.TypeVar("NumVacant")
10+
Nx = types.TypeVar("Nx")
11+
Ny = types.TypeVar("Ny")
12+
13+
14+
@decl.statement(dialect=dialect)
15+
class Vacate(ir.Statement):
16+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
17+
18+
zone: ir.SSAValue = info.argument(grid.GridType[Nx, Ny])
19+
vacancies: ir.SSAValue = info.argument(
20+
ilist.IListType[types.Tuple[types.Int, types.Int], NumVacant]
21+
)
22+
result: ir.ResultValue = info.result(FilledGridType[Nx, Ny])
23+
24+
25+
@decl.statement(dialect=dialect)
26+
class Fill(ir.Statement):
27+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
28+
29+
zone: ir.SSAValue = info.argument(grid.GridType[Nx, Ny])
30+
filled: ir.SSAValue = info.argument(
31+
ilist.IListType[types.Tuple[types.Int, types.Int], NumVacant]
32+
)
33+
result: ir.ResultValue = info.result(FilledGridType[Nx, Ny])
34+
35+
36+
@decl.statement(dialect=dialect)
37+
class GetParent(ir.Statement):
38+
name = "get_parent"
39+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
40+
41+
filled_grid: ir.SSAValue = info.argument(FilledGridType[Nx, Ny])
42+
result: ir.ResultValue = info.result(grid.GridType[Nx, Ny])
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from dataclasses import dataclass, field
2+
from functools import cached_property
3+
from itertools import product
4+
from typing import Any, Iterable, Sequence, TypeVar
5+
6+
from bloqade.geometry.dialects import grid
7+
from kirin import types
8+
from kirin.dialects import ilist
9+
10+
NumX = TypeVar("NumX")
11+
NumY = TypeVar("NumY")
12+
13+
14+
@dataclass(eq=False)
15+
class FilledGrid(grid.Grid[NumX, NumY]):
16+
x_spacing: tuple[float, ...] = field(init=False)
17+
y_spacing: tuple[float, ...] = field(init=False)
18+
x_init: float | None = field(init=False)
19+
y_init: float | None = field(init=False)
20+
21+
parent: grid.Grid[NumX, NumY]
22+
vacancies: frozenset[tuple[int, int]]
23+
24+
def __post_init__(self):
25+
self.x_spacing = self.parent.x_spacing
26+
self.y_spacing = self.parent.y_spacing
27+
self.x_init = self.parent.x_init
28+
self.y_init = self.parent.y_init
29+
30+
self.type = types.Generic(
31+
FilledGrid,
32+
types.Literal(len(self.x_spacing) + 1),
33+
types.Literal(len(self.y_spacing) + 1),
34+
)
35+
36+
def __hash__(self):
37+
return hash((FilledGrid, self.parent, self.vacancies))
38+
39+
def __eq__(self, other: Any) -> bool:
40+
return (
41+
isinstance(other, FilledGrid)
42+
and self.parent == other.parent
43+
and self.vacancies == other.vacancies
44+
)
45+
46+
@cached_property
47+
def positions(self) -> ilist.IList[tuple[float, float], Any]:
48+
positions = tuple(
49+
(x, y)
50+
for (ix, x), (iy, y) in product(
51+
enumerate(self.x_positions), enumerate(self.y_positions)
52+
)
53+
if (ix, iy) not in self.vacancies
54+
)
55+
56+
return ilist.IList(positions)
57+
58+
@classmethod
59+
def fill(
60+
cls, grid_obj: grid.Grid[NumX, NumY], filled: Sequence[tuple[int, int]]
61+
) -> "FilledGrid[NumX, NumY]":
62+
num_x, num_y = grid_obj.shape
63+
64+
if isinstance(grid_obj, FilledGrid):
65+
vacancies = grid_obj.vacancies
66+
else:
67+
vacancies = frozenset(product(range(num_x), range(num_y)))
68+
69+
vacancies = vacancies - frozenset(filled)
70+
71+
return cls(parent=grid_obj, vacancies=vacancies)
72+
73+
@classmethod
74+
def vacate(
75+
cls, grid_obj: grid.Grid[NumX, NumY], vacancies: Iterable[tuple[int, int]]
76+
) -> "FilledGrid[NumX, NumY]":
77+
78+
if isinstance(grid_obj, FilledGrid):
79+
input_vacancies = grid_obj.vacancies
80+
else:
81+
input_vacancies = frozenset()
82+
83+
input_vacancies = input_vacancies.union(vacancies)
84+
85+
return cls(parent=grid_obj, vacancies=input_vacancies)
86+
87+
def get_view( # type: ignore
88+
self, x_indices: ilist.IList[int, Any], y_indices: ilist.IList[int, Any]
89+
):
90+
remapping_x = {ix: i for i, ix in enumerate(x_indices)}
91+
remapping_y = {iy: i for i, iy in enumerate(y_indices)}
92+
return FilledGrid(
93+
parent=self.parent.get_view(x_indices, y_indices),
94+
vacancies=frozenset(
95+
(remapping_x[x], remapping_y[y])
96+
for x, y in self.vacancies
97+
if x in remapping_x and y in remapping_y
98+
),
99+
)
100+
101+
def shift(self, x_shift: float, y_shift: float):
102+
return FilledGrid(
103+
parent=self.parent.shift(x_shift, y_shift),
104+
vacancies=self.vacancies,
105+
)
106+
107+
def scale(self, x_scale: float, y_scale: float):
108+
return FilledGrid(
109+
parent=self.parent.scale(x_scale, y_scale),
110+
vacancies=self.vacancies,
111+
)
112+
113+
def repeat(self, x_times: int, y_times: int, x_gap: float, y_gap: float):
114+
new_parent = self.parent.repeat(x_times, y_times, x_gap, y_gap)
115+
x_dim, y_dim = self.shape
116+
vacancies = frozenset(
117+
(x + x_dim * i, y + y_dim * j)
118+
for i, j, (x, y) in product(range(x_times), range(y_times), self.vacancies)
119+
)
120+
return FilledGrid.vacate(new_parent, vacancies)
121+
122+
123+
FilledGridType = types.Generic(FilledGrid, types.TypeVar("NumX"), types.TypeVar("NumY"))

src/bloqade/shuttle/prelude.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from bloqade.shuttle.dialects import (
1212
action,
1313
atom,
14+
filled,
1415
gate,
1516
init,
1617
measure,
@@ -23,7 +24,7 @@
2324
from bloqade.shuttle.rewrite.desugar import DesugarTurnOffRewrite, DesugarTurnOnRewrite
2425

2526

26-
@ir.dialect_group(structural.union([spec, grid, atom, gate, op, qubit]))
27+
@ir.dialect_group(structural.union([spec, grid, filled, atom, gate, op, qubit]))
2728
def kernel(self):
2829
def run_pass(
2930
mt: ir.Method,
@@ -50,7 +51,7 @@ def run_pass(
5051

5152

5253
# We dont allow [cf, aod, schedule] appear in move function
53-
@ir.dialect_group(structural.union([action, spec, grid]))
54+
@ir.dialect_group(structural.union([action, spec, grid, filled]))
5455
def tweezer(self):
5556
fold_pass = Fold(self)
5657
typeinfer_pass = TypeInfer(self)
@@ -91,7 +92,7 @@ def run_pass(
9192

9293
# no action allow. can have cf, with addtional spec
9394
@ir.dialect_group(
94-
structural.union([init, schedule, path, grid, spec, gate, op, measure])
95+
structural.union([init, schedule, path, grid, filled, spec, gate, op, measure])
9596
)
9697
def move(self):
9798
schedule_to_path = ScheduleToPath(self)

test/filled/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)