Skip to content

Commit 953b099

Browse files
authored
Adding filled dialect to represent grids with vacancies in it. (#41)
* updating to new kirin version + adding filled dialect * updating filled init and also adding/updating tests * Adding doc strings * Adding test * Updating documentation nav * Fixing potential bug in filled construction * review suggestions
1 parent 9952b46 commit 953b099

File tree

19 files changed

+557
-11
lines changed

19 files changed

+557
-11
lines changed

mkdocs.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ nav:
2020
- grid:
2121
- statements: reference/bloqade/geometry/dialects/grid/_interface.md
2222
- types: reference/bloqade/geometry/dialects/grid/types.md
23-
23+
- filled:
24+
- statements: reference/bloqade/geometry/dialects/filled/_interface.md
25+
- types: reference/bloqade/geometry/dialects/filled/types.md
2426

2527

2628
theme:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ authors = [
77
{ name = "Phillip Weinberg", email = "pweinberg@quera.com" }
88
]
99
dependencies = [
10-
"kirin-toolchain~=0.18.0",
10+
"kirin-toolchain~=0.22.0",
1111
]
1212
requires-python = ">= 3.10"
1313

src/bloqade/geometry/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,7 @@
1-
from .dialects import grid as grid
1+
from .dialects.filled import _interface as filled
2+
from .dialects.grid import _interface as grid
3+
4+
__all__ = [
5+
"grid",
6+
"filled",
7+
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from ._dialect import dialect as dialect
2+
from ._interface import (
3+
fill as fill,
4+
get_parent as get_parent,
5+
vacate as vacate,
6+
)
7+
from .concrete import FilledGridMethods as FilledGridMethods
8+
from .stmts import Fill as Fill, GetParent as GetParent, Vacate as Vacate
9+
from .types import FilledGrid as FilledGrid, FilledGridType as FilledGridType
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("geometry.filled")
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from typing import Any, TypeVar
2+
3+
from kirin.dialects import ilist
4+
from kirin.lowering import wraps as _wraps
5+
6+
from bloqade.geometry.dialects import grid
7+
8+
from .stmts import Fill, GetParent, Vacate
9+
from .types import FilledGrid
10+
11+
Nx = TypeVar("Nx")
12+
Ny = TypeVar("Ny")
13+
14+
15+
@_wraps(Vacate)
16+
def vacate(
17+
zone: grid.Grid[Nx, Ny],
18+
vacancies: ilist.IList[tuple[int, int], Any],
19+
) -> FilledGrid[Nx, Ny]:
20+
"""Create a FilledGrid by vacating specified positions from a grid.
21+
22+
Args:
23+
zone: The original grid from which positions will be vacated.
24+
vacancies: An IList of (x_index, y_index) tuples indicating positions to vacate
25+
26+
Returns:
27+
A FilledGrid with the specified vacancies.
28+
29+
"""
30+
...
31+
32+
33+
@_wraps(Fill)
34+
def fill(
35+
zone: grid.Grid[Nx, Ny],
36+
filled: ilist.IList[tuple[int, int], Any],
37+
) -> FilledGrid[Nx, Ny]:
38+
"""Create a FilledGrid by filling specified positions in a grid.
39+
40+
Args:
41+
zone: The original grid in which positions will be filled.
42+
filled: An IList of (x_index, y_index) tuples indicating positions to fill
43+
44+
Returns:
45+
A FilledGrid with the specified positions filled.
46+
47+
"""
48+
...
49+
50+
51+
@_wraps(GetParent)
52+
def get_parent(filled_grid: FilledGrid[Nx, Ny]) -> grid.Grid[Nx, Ny]:
53+
"""Retrieve the parent grid of a FilledGrid.
54+
55+
Args:
56+
filled_grid: The FilledGrid whose parent grid is to be retrieved.
57+
58+
Returns:
59+
The parent grid of the provided FilledGrid.
60+
61+
"""
62+
...
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Any
2+
3+
from kirin.dialects import ilist
4+
from kirin.interp import (
5+
Frame,
6+
Interpreter,
7+
MethodTable,
8+
impl,
9+
)
10+
11+
from bloqade.geometry.dialects.grid.types import Grid
12+
13+
from . import stmts
14+
from ._dialect import dialect
15+
from .types import FilledGrid
16+
17+
18+
@dialect.register
19+
class FilledGridMethods(MethodTable):
20+
21+
@impl(stmts.Vacate)
22+
def vacate(self, interp: Interpreter, frame: Frame, stmt: stmts.Vacate):
23+
zone = frame.get_casted(stmt.zone, Grid)
24+
vacancies = frame.get_casted(stmt.vacancies, ilist.IList[tuple[int, int], Any])
25+
return (FilledGrid.vacate(zone, vacancies),)
26+
27+
@impl(stmts.Fill)
28+
def fill(self, interp: Interpreter, frame: Frame, stmt: stmts.Fill):
29+
zone = frame.get_casted(stmt.zone, Grid)
30+
filled = frame.get_casted(stmt.filled, ilist.IList[tuple[int, int], Any])
31+
return (FilledGrid.fill(zone, filled),)
32+
33+
@impl(stmts.GetParent)
34+
def get_parent(self, interp: Interpreter, frame: Frame, stmt: stmts.GetParent):
35+
filled_grid = frame.get_casted(stmt.filled_grid, FilledGrid)
36+
return (filled_grid.parent,)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from kirin import decl, ir, lowering, types
2+
from kirin.decl import info
3+
from kirin.dialects import ilist
4+
5+
from bloqade.geometry.dialects import grid
6+
7+
from ._dialect import dialect
8+
from .types import FilledGridType
9+
10+
NumVacant = types.TypeVar("NumVacant")
11+
Nx = types.TypeVar("Nx")
12+
Ny = types.TypeVar("Ny")
13+
14+
15+
@decl.statement(dialect=dialect)
16+
class Vacate(ir.Statement):
17+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
18+
19+
zone: ir.SSAValue = info.argument(grid.GridType[Nx, Ny])
20+
vacancies: ir.SSAValue = info.argument(
21+
ilist.IListType[types.Tuple[types.Int, types.Int], NumVacant]
22+
)
23+
result: ir.ResultValue = info.result(FilledGridType[Nx, Ny])
24+
25+
26+
@decl.statement(dialect=dialect)
27+
class Fill(ir.Statement):
28+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
29+
30+
zone: ir.SSAValue = info.argument(grid.GridType[Nx, Ny])
31+
filled: ir.SSAValue = info.argument(
32+
ilist.IListType[types.Tuple[types.Int, types.Int], NumVacant]
33+
)
34+
result: ir.ResultValue = info.result(FilledGridType[Nx, Ny])
35+
36+
37+
@decl.statement(dialect=dialect)
38+
class GetParent(ir.Statement):
39+
name = "get_parent"
40+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
41+
42+
filled_grid: ir.SSAValue = info.argument(FilledGridType[Nx, Ny])
43+
result: ir.ResultValue = info.result(grid.GridType[Nx, Ny])
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
from dataclasses import dataclass, field
2+
from functools import cached_property
3+
from itertools import product
4+
from typing import Any, Iterable, Sequence, TypeVar
5+
6+
from kirin import types
7+
from kirin.dialects import ilist
8+
9+
from bloqade.geometry.dialects import grid
10+
11+
NumX = TypeVar("NumX")
12+
NumY = TypeVar("NumY")
13+
14+
15+
@dataclass(eq=False)
16+
class FilledGrid(grid.Grid[NumX, NumY]):
17+
x_spacing: tuple[float, ...] = field(init=False)
18+
y_spacing: tuple[float, ...] = field(init=False)
19+
x_init: float | None = field(init=False)
20+
y_init: float | None = field(init=False)
21+
22+
parent: grid.Grid[NumX, NumY]
23+
vacancies: frozenset[tuple[int, int]]
24+
25+
def __post_init__(self):
26+
self.x_spacing = self.parent.x_spacing
27+
self.y_spacing = self.parent.y_spacing
28+
self.x_init = self.parent.x_init
29+
self.y_init = self.parent.y_init
30+
31+
self.type = types.Generic(
32+
FilledGrid,
33+
types.Literal(len(self.x_spacing) + 1),
34+
types.Literal(len(self.y_spacing) + 1),
35+
)
36+
37+
def __hash__(self):
38+
return hash((self.parent, self.vacancies))
39+
40+
def __eq__(self, other: Any) -> bool:
41+
return (
42+
isinstance(other, FilledGrid)
43+
and self.parent == other.parent
44+
and self.vacancies == other.vacancies
45+
)
46+
47+
def is_equal(self, other: Any) -> bool:
48+
return self == other
49+
50+
@cached_property
51+
def positions(self) -> ilist.IList[tuple[float, float], Any]:
52+
positions = tuple(
53+
(x, y)
54+
for (ix, x), (iy, y) in product(
55+
enumerate(self.x_positions), enumerate(self.y_positions)
56+
)
57+
if (ix, iy) not in self.vacancies
58+
)
59+
60+
return ilist.IList(positions)
61+
62+
@classmethod
63+
def fill(
64+
cls, grid_obj: grid.Grid[NumX, NumY], filled: Sequence[tuple[int, int]]
65+
) -> "FilledGrid[NumX, NumY]":
66+
num_x, num_y = grid_obj.shape
67+
68+
if isinstance(grid_obj, FilledGrid):
69+
vacancies = grid_obj.vacancies
70+
parent = grid_obj.parent
71+
else:
72+
vacancies = frozenset(product(range(num_x), range(num_y)))
73+
parent = grid_obj
74+
75+
vacancies = vacancies - frozenset(filled)
76+
77+
return cls(parent=parent, vacancies=vacancies)
78+
79+
@classmethod
80+
def vacate(
81+
cls, grid_obj: grid.Grid[NumX, NumY], vacancies: Iterable[tuple[int, int]]
82+
) -> "FilledGrid[NumX, NumY]":
83+
84+
if isinstance(grid_obj, FilledGrid):
85+
input_vacancies = grid_obj.vacancies
86+
parent = grid_obj.parent
87+
else:
88+
input_vacancies = frozenset()
89+
parent = grid_obj
90+
91+
input_vacancies = input_vacancies.union(vacancies)
92+
93+
return cls(parent=parent, vacancies=input_vacancies)
94+
95+
def get_view( # type: ignore
96+
self, x_indices: ilist.IList[int, Any], y_indices: ilist.IList[int, Any]
97+
):
98+
remapping_x = {ix: i for i, ix in enumerate(x_indices)}
99+
remapping_y = {iy: i for i, iy in enumerate(y_indices)}
100+
return FilledGrid(
101+
parent=self.parent.get_view(x_indices, y_indices),
102+
vacancies=frozenset(
103+
(remapping_x[x], remapping_y[y])
104+
for x, y in self.vacancies
105+
if x in remapping_x and y in remapping_y
106+
),
107+
)
108+
109+
def shift(self, x_shift: float, y_shift: float):
110+
return FilledGrid(
111+
parent=self.parent.shift(x_shift, y_shift),
112+
vacancies=self.vacancies,
113+
)
114+
115+
def scale(self, x_scale: float, y_scale: float):
116+
return FilledGrid(
117+
parent=self.parent.scale(x_scale, y_scale),
118+
vacancies=self.vacancies,
119+
)
120+
121+
def repeat(self, x_times: int, y_times: int, x_gap: float, y_gap: float):
122+
new_parent = self.parent.repeat(x_times, y_times, x_gap, y_gap)
123+
x_dim, y_dim = self.shape
124+
vacancies = frozenset(
125+
(x + x_dim * i, y + y_dim * j)
126+
for i, j, (x, y) in product(range(x_times), range(y_times), self.vacancies)
127+
)
128+
return FilledGrid.vacate(new_parent, vacancies)
129+
130+
def row_x_pos(self, row_index: int):
131+
x_vacancies = {x for x, y in self.vacancies if y == row_index}
132+
return ilist.IList(
133+
[x for i, x in enumerate(self.parent.x_positions) if i not in x_vacancies]
134+
)
135+
136+
def col_y_pos(self, column_index: int):
137+
y_vacancies = {y for x, y in self.vacancies if x == column_index}
138+
return ilist.IList(
139+
[y for i, y in enumerate(self.y_positions) if i not in y_vacancies]
140+
)
141+
142+
143+
FilledGridType = types.Generic(FilledGrid, types.TypeVar("NumX"), types.TypeVar("NumY"))

src/bloqade/geometry/dialects/grid/_interface.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from kirin.lowering import wraps as _wraps
55

66
from .stmts import (
7+
ColYPos,
78
FromPositions,
89
Get,
910
GetSubGrid,
@@ -14,6 +15,7 @@
1415
New,
1516
Positions,
1617
Repeat,
18+
RowXPos,
1719
Scale,
1820
Shape,
1921
Shift,
@@ -275,3 +277,33 @@ def shape(grid: Grid) -> tuple[int, int]:
275277
tuple[int, int]: a tuple of (num_x, num_y)
276278
"""
277279
...
280+
281+
282+
@_wraps(RowXPos)
283+
def row_xpos(
284+
grid: Grid[typing.Any, typing.Any], row_index: int
285+
) -> ilist.IList[float, typing.Any]:
286+
"""Get the x positions of a specific row in the grid.
287+
288+
Args:
289+
grid (Grid): a grid object
290+
row_index (int): the index of the row to get x positions for
291+
Returns:
292+
ilist.IList[float, typing.Any]: a list of x positions for the specified row
293+
"""
294+
...
295+
296+
297+
@_wraps(ColYPos)
298+
def col_ypos(
299+
grid: Grid[typing.Any, typing.Any], column_index: int
300+
) -> ilist.IList[float, typing.Any]:
301+
"""Get the y positions of a specific column in the grid.
302+
303+
Args:
304+
grid (Grid): a grid object
305+
column_index (int | None): the index of the column to get y positions for, or None for all columns
306+
Returns:
307+
ilist.IList[float, typing.Any]: a list of y positions for the specified column
308+
"""
309+
...

0 commit comments

Comments
 (0)