Skip to content

Commit 5652952

Browse files
committed
adding auto encoding
1 parent d8d7f81 commit 5652952

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

src/bloqade/lanes/types/arch.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
22
from typing import Generic, Sequence
33

44
import numpy as np
55

6+
from bloqade.lanes.types.encoding import EncodingType
7+
68
from .word import SiteType, Word
79

810

@@ -33,6 +35,10 @@ class ArchSpec(Generic[SiteType]):
3335
"""List of all word buses in the architecture by word address."""
3436
site_bus_compatibility: tuple[frozenset[int], ...]
3537
"""Mapping from word id indicating which other word ids can execute site-buses in parallel."""
38+
encoding: EncodingType = field(init=False)
39+
40+
def __post_init__(self):
41+
object.__setattr__(self, "encoding", EncodingType.infer(self)) # type: ignore
3642

3743
def plot(
3844
self,

src/bloqade/lanes/types/encoding.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,27 @@ class EncodingType(enum.Enum):
1717
BIT32 = 0
1818
BIT64 = 1
1919

20+
@staticmethod
21+
def infer(spec) -> "EncodingType":
22+
num_words = len(spec.words)
23+
num_sites = len(spec.words[0].sites)
24+
num_site_buses = len(spec.site_buses)
25+
num_word_buses = len(spec.word_buses)
26+
27+
max_id = max(
28+
num_words - 1,
29+
num_sites - 1,
30+
num_site_buses - 1,
31+
num_word_buses - 1,
32+
)
33+
34+
if max_id < 256:
35+
return EncodingType.BIT32
36+
elif max_id < 65536:
37+
return EncodingType.BIT64
38+
else:
39+
raise ValueError("Architecture too large to encode with 64-bit addresses")
40+
2041

2142
@dataclass(frozen=True)
2243
class MoveType(abc.ABC):

src/bloqade/lanes/types/path.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import rustworkx as nx
66

77
from .arch import ArchSpec
8-
from .encoding import Direction, EncodingType, InterMove, IntraMove, MoveType
8+
from .encoding import Direction, InterMove, IntraMove, MoveType
99

1010

1111
@dataclass(frozen=True)
@@ -56,7 +56,7 @@ def __post_init__(self):
5656
lane_addr,
5757
)
5858

59-
def extract_lanes_from_path(self, path: list[int], encoding: EncodingType):
59+
def extract_lanes_from_path(self, path: list[int]):
6060
"""Given a path as a list of node indices, extract the lane addresses."""
6161
if len(path) < 2:
6262
raise ValueError("Path must have at least two nodes to extract lanes.")
@@ -67,15 +67,14 @@ def extract_lanes_from_path(self, path: list[int], encoding: EncodingType):
6767

6868
lane: MoveType = self.site_graph.get_edge_data(src, dst)
6969

70-
lanes.append(lane.get_address(encoding))
70+
lanes.append(lane.get_address(self.spec.encoding))
7171
return lanes
7272

7373
def find_path(
7474
self,
7575
start: tuple[int, int],
7676
end: tuple[int, int],
7777
occupied: frozenset[tuple[int, int]] = frozenset(),
78-
encoding: EncodingType = EncodingType.BIT64,
7978
path_heuristic: Callable[[list[tuple[int, int]]], float] = lambda _: 0.0,
8079
):
8180
"""Find a path from start to end avoiding occupied sites.
@@ -114,7 +113,7 @@ def find_path(
114113
path_nodes,
115114
key=lambda p: path_heuristic([self.physical_addresses[n] for n in p]),
116115
)
117-
lanes = self.extract_lanes_from_path(path, encoding)
116+
lanes = self.extract_lanes_from_path(path)
118117
return lanes, occupied.union(
119118
frozenset(self.physical_addresses[n] for n in path)
120119
)

0 commit comments

Comments
 (0)