Skip to content

Commit cd75113

Browse files
weinbe58claude
andauthored
refactor(python): port IR address validation from Python to Rust (#348)
* refactor(python): port IR address validation from Python to Rust Replace pure-Python validation methods with Rust-backed group validators. The IR validation pass now delegates entire groups of addresses to Rust in a single call instead of validating one-by-one. Changes: - Add check_location_group() and check_lane_group() wrapper methods on ArchSpec that extract _inner Rust objects and call the existing check_locations/check_lanes PyO3 bindings - Rewrite validate_location() and validate_lane() to delegate to Rust via single-element group calls (backwards compatible) - Rewrite compatible_lane_error() and compatible_lanes() to delegate to Rust group validation - Simplify validation/address.py: replace filter_by_error loops and pairwise compatibility checks with single group validation calls - Remove ~80 lines of pure-Python validation logic from ArchSpec - Update test expected error messages to match Rust format Closes #287 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix(python): address PR review feedback for Rust address validation - Tighten return types of check_location_group and check_lane_group to Sequence[LocationGroupError] and Sequence[LaneGroupError] - Add TYPE_CHECKING import for the error types - Include lane count in validation error messages for diagnostics Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * fix(python): use proper return types in _native.pyi for check_locations/check_lanes Update type stubs to return list[LocationGroupError] and list[LaneGroupError] instead of list[Exception], removing the need for type: ignore comments in the wrapper methods. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 96a6d48 commit cd75113

File tree

4 files changed

+81
-192
lines changed

4 files changed

+81
-192
lines changed

python/bloqade/lanes/bytecode/_native.pyi

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
from typing import Optional, final
44

5+
from bloqade.lanes.bytecode.exceptions import (
6+
LaneGroupError,
7+
LocationGroupError,
8+
)
9+
510
# ── Enums ──
611

712
@final
@@ -724,7 +729,9 @@ class ArchSpec:
724729
"""
725730
...
726731

727-
def check_locations(self, locations: list[LocationAddress]) -> list[Exception]:
732+
def check_locations(
733+
self, locations: list[LocationAddress]
734+
) -> list[LocationGroupError]:
728735
"""Validate a group of location addresses against this architecture.
729736
730737
Checks for duplicate addresses and invalid word/site combinations.
@@ -733,11 +740,11 @@ class ArchSpec:
733740
locations (list[LocationAddress]): Location addresses to validate.
734741
735742
Returns:
736-
list[Exception]: ``LocationGroupError`` subclass instances (empty if all valid).
743+
list[LocationGroupError]: Error instances (empty if all valid).
737744
"""
738745
...
739746

740-
def check_lanes(self, lanes: list[LaneAddress]) -> list[Exception]:
747+
def check_lanes(self, lanes: list[LaneAddress]) -> list[LaneGroupError]:
741748
"""Validate a group of lane addresses against this architecture.
742749
743750
Checks for duplicates, invalid addresses, bus consistency, and
@@ -747,7 +754,7 @@ class ArchSpec:
747754
lanes (list[LaneAddress]): Lane addresses to validate.
748755
749756
Returns:
750-
list[Exception]: ``LaneGroupError`` subclass instances (empty if all valid).
757+
list[LaneGroupError]: Error instances (empty if all valid).
751758
"""
752759
...
753760

python/bloqade/lanes/layout/arch.py

Lines changed: 39 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
if TYPE_CHECKING:
3030
from collections.abc import Iterator
3131

32+
from bloqade.lanes.bytecode.exceptions import LaneGroupError, LocationGroupError
33+
3234

3335
class ArchSpec:
3436
"""Architecture specification for a quantum device."""
@@ -378,104 +380,62 @@ def show(
378380
)
379381
plt.show()
380382

383+
def check_location_group(
384+
self, locations: Sequence[LocationAddress]
385+
) -> Sequence[LocationGroupError]:
386+
"""Validate a group of location addresses via Rust.
387+
388+
Returns a list of LocationGroupError exceptions (empty if all valid).
389+
"""
390+
rust_addrs = [loc._inner for loc in locations]
391+
return self._inner.check_locations(rust_addrs)
392+
393+
def check_lane_group(
394+
self, lanes: Sequence[LaneAddress]
395+
) -> Sequence[LaneGroupError]:
396+
"""Validate a group of lane addresses via Rust.
397+
398+
Checks individual lane validity, group consistency (direction, bus_id,
399+
move_type), bus membership, and AOD geometry constraints.
400+
Returns a list of LaneGroupError exceptions (empty if all valid).
401+
"""
402+
rust_addrs = [lane._inner for lane in lanes]
403+
return self._inner.check_lanes(rust_addrs)
404+
381405
def compatible_lane_error(self, lane1: LaneAddress, lane2: LaneAddress) -> set[str]:
382406
"""Get error messages if two lanes are not compatible.
383407
384-
NOTE: this function assumes that both lanes are valid.
408+
Delegates to Rust group validation.
385409
"""
386-
errors: set[str] = set()
387-
if lane1.direction != lane2.direction:
388-
errors.add("Lanes have different directions")
389-
390-
if lane1.move_type == MoveType.SITE and lane2.move_type == MoveType.SITE:
391-
if lane1.bus_id != lane2.bus_id:
392-
errors.add("Lanes are on different site-buses")
393-
if lane1.word_id == lane2.word_id and lane1.site_id == lane2.site_id:
394-
errors.add("Lanes are the same")
395-
elif lane1.move_type == MoveType.WORD and lane2.move_type == MoveType.WORD:
396-
if lane2.bus_id != lane1.bus_id:
397-
errors.add("Lanes are on different word-buses")
398-
if lane1.word_id == lane2.word_id and lane1.site_id == lane2.site_id:
399-
errors.add("Lanes are the same")
400-
else:
401-
errors.add("Lanes have different move types")
402-
403-
return errors
410+
errors = self.check_lane_group([lane1, lane2])
411+
return {str(e) for e in errors}
404412

405413
def compatible_lanes(self, lane1: LaneAddress, lane2: LaneAddress) -> bool:
406414
"""Check if two lanes are compatible (can be executed in parallel)."""
407-
return len(self.compatible_lane_error(lane1, lane2)) == 0
415+
return len(self.check_lane_group([lane1, lane2])) == 0
408416

409417
def validate_location(self, location_address: LocationAddress) -> set[str]:
410-
"""Check if a location address is valid in this architecture."""
411-
errors: set[str] = set()
418+
"""Check if a location address is valid in this architecture.
412419
413-
num_words = len(self.words)
414-
if location_address.word_id >= num_words:
415-
errors.add(
416-
f"Word id {location_address.word_id} out of range of {num_words}"
417-
)
418-
return errors
419-
420-
word = self.words[location_address.word_id]
420+
Delegates to Rust validation.
421+
"""
422+
errors = self.check_location_group([location_address])
423+
return {str(e) for e in errors}
421424

422-
num_sites = len(word.site_indices)
423-
if location_address.site_id >= num_sites:
424-
errors.add(
425-
f"Site id {location_address.site_id} out of range of {num_sites}"
426-
)
425+
def validate_lane(self, lane_address: LaneAddress) -> set[str]:
426+
"""Check if a lane address is valid in this architecture.
427427
428-
return errors
428+
Delegates to Rust validation.
429+
"""
430+
errors = self.check_lane_group([lane_address])
431+
return {str(e) for e in errors}
429432

430433
def get_lane_address(
431434
self, src: LocationAddress, dst: LocationAddress
432435
) -> LaneAddress | None:
433436
"""Given an input tuple of locations, gets the lane (w/direction)."""
434437
return self._lane_map.get((src, dst))
435438

436-
def validate_lane(self, lane_address: LaneAddress) -> set[str]:
437-
"""Check if a lane address is valid in this architecture."""
438-
errors = self.validate_location(lane_address.src_site())
439-
440-
if lane_address.move_type == MoveType.WORD:
441-
if lane_address.site_id not in self.has_word_buses:
442-
errors.add(
443-
f"Site {lane_address.site_id} does not support word-bus moves"
444-
)
445-
num_word_buses = len(self.word_buses)
446-
if lane_address.bus_id >= num_word_buses:
447-
errors.add(
448-
f"Bus id {lane_address.bus_id} out of range of {num_word_buses}"
449-
)
450-
return errors
451-
452-
bus = self.word_buses[lane_address.bus_id]
453-
if lane_address.word_id not in bus.src:
454-
errors.add(f"Word {lane_address.word_id} not in bus source {bus.src}")
455-
456-
elif lane_address.move_type == MoveType.SITE:
457-
if lane_address.word_id not in self.has_site_buses:
458-
errors.add(
459-
f"Word {lane_address.word_id} does not support site-bus moves"
460-
)
461-
462-
num_site_buses = len(self.site_buses)
463-
if lane_address.bus_id >= num_site_buses:
464-
errors.add(
465-
f"Bus id {lane_address.bus_id} out of range of {num_site_buses}"
466-
)
467-
return errors
468-
469-
bus = self.site_buses[lane_address.bus_id]
470-
if lane_address.site_id not in bus.src:
471-
errors.add(f"Site {lane_address.site_id} not in bus source {bus.src}")
472-
else:
473-
errors.add(
474-
f"Unsupported move type {lane_address.move_type} for lane address"
475-
)
476-
477-
return errors
478-
479439
def get_endpoints(
480440
self, lane_address: LaneAddress
481441
) -> tuple[LocationAddress, LocationAddress]:

python/bloqade/lanes/validation/address.py

Lines changed: 29 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass, field
22
from itertools import chain
3-
from typing import Any, Callable, Iterable, TypeVar
3+
from typing import Any
44

55
from kirin import interp, ir
66
from kirin.analysis.forward import Forward, ForwardFrame
@@ -9,7 +9,7 @@
99

1010
from bloqade.lanes.dialects import move
1111
from bloqade.lanes.layout.arch import ArchSpec
12-
from bloqade.lanes.layout.encoding import Encoder, LaneAddress
12+
from bloqade.lanes.layout.encoding import LaneAddress, LocationAddress
1313

1414

1515
@dataclass
@@ -25,29 +25,29 @@ def method_self(self, method: ir.Method) -> EmptyLattice:
2525
def eval_fallback(self, frame: ForwardFrame[EmptyLattice], node: ir.Statement):
2626
return tuple(EmptyLattice.bottom() for _ in node.results)
2727

28-
AddressType = TypeVar("AddressType", bound=Encoder)
29-
30-
def filter_by_error(
31-
self,
32-
addresses: Iterable[AddressType],
33-
checker: Callable[[AddressType], set[str]],
28+
def report_location_errors(
29+
self, node: ir.Statement, locations: tuple[LocationAddress, ...]
3430
):
35-
"""Apply a checker function to a sequence of addresses, yielding those with errors
36-
along with their error messages.
37-
38-
Args:
39-
addresses: A tuple of address objects to be checked.
40-
checker: A function that takes an address and returns a set of error messages.
41-
if the set is empty, the address is considered valid.
42-
Yields:
43-
Tuples of (address, error message) for each address that has an error.
44-
"""
45-
46-
def has_error(tup: tuple[Any, set[str]]) -> bool:
47-
return len(tup[1]) > 0
31+
"""Validate a group of locations via Rust and report errors."""
32+
for error in self.arch_spec.check_location_group(locations):
33+
self.add_validation_error(
34+
node,
35+
ir.ValidationError(
36+
node,
37+
f"Invalid location address: {error}",
38+
),
39+
)
4840

49-
error_checks = zip(addresses, map(checker, addresses))
50-
yield from filter(has_error, error_checks)
41+
def report_lane_errors(self, node: ir.Statement, lanes: tuple[LaneAddress, ...]):
42+
"""Validate a group of lanes via Rust and report errors."""
43+
for error in self.arch_spec.check_lane_group(lanes):
44+
self.add_validation_error(
45+
node,
46+
ir.ValidationError(
47+
node,
48+
f"Invalid lane group (count={len(lanes)}): {error}",
49+
),
50+
)
5151

5252

5353
@move.dialect.register(key="move.address.validation")
@@ -59,46 +59,8 @@ def lane_checker(
5959
frame: ForwardFrame[EmptyLattice],
6060
node: move.Move,
6161
):
62-
if len(node.lanes) == 0:
63-
return ()
64-
65-
invalid_lanes = []
66-
for lane, error_msgs in _interp.filter_by_error(
67-
node.lanes, _interp.arch_spec.validate_lane
68-
):
69-
invalid_lanes.append(lane)
70-
for error_msg in error_msgs:
71-
_interp.add_validation_error(
72-
node,
73-
ir.ValidationError(
74-
node,
75-
f"Invalid lane address {lane!r}: {error_msg}",
76-
),
77-
)
78-
79-
valid_lanes = set(node.lanes) - set(invalid_lanes)
80-
if len(valid_lanes) == 0:
81-
return
82-
83-
first_lane = valid_lanes.pop()
84-
incompatible_lanes = []
85-
86-
def validate_compatible_lane(lane: LaneAddress):
87-
return _interp.arch_spec.compatible_lane_error(first_lane, lane)
88-
89-
for lane, error_msgs in _interp.filter_by_error(
90-
valid_lanes, validate_compatible_lane
91-
):
92-
incompatible_lanes.append(lane)
93-
for error_msg in error_msgs:
94-
_interp.add_validation_error(
95-
node,
96-
ir.ValidationError(
97-
node,
98-
f"Incompatible lane address {first_lane!r} with lane {lane!r}: {error_msg}",
99-
),
100-
)
101-
62+
if len(node.lanes) > 0:
63+
_interp.report_lane_errors(node, node.lanes)
10264
return (EmptyLattice.bottom(),)
10365

10466
@interp.impl(move.LogicalInitialize)
@@ -111,23 +73,7 @@ def location_checker(
11173
frame: ForwardFrame[EmptyLattice],
11274
node: move.LogicalInitialize | move.LocalR | move.LocalRz | move.Fill,
11375
):
114-
invalid_locations = list(
115-
_interp.filter_by_error(
116-
node.location_addresses,
117-
_interp.arch_spec.validate_location,
118-
)
119-
)
120-
121-
for lane_address, error_msgs in invalid_locations:
122-
for error_msg in error_msgs:
123-
_interp.add_validation_error(
124-
node,
125-
ir.ValidationError(
126-
node,
127-
f"Invalid location address {lane_address!r}: {error_msg}",
128-
),
129-
)
130-
76+
_interp.report_location_errors(node, node.location_addresses)
13177
return (EmptyLattice.bottom(),)
13278

13379
@interp.impl(move.GetFutureResult)
@@ -137,17 +83,7 @@ def location_checker_get_future(
13783
frame: ForwardFrame[EmptyLattice],
13884
node: move.GetFutureResult,
13985
):
140-
location_address = node.location_address
141-
error_msgs = _interp.arch_spec.validate_location(location_address)
142-
143-
for error_msg in error_msgs:
144-
_interp.add_validation_error(
145-
node,
146-
ir.ValidationError(
147-
node,
148-
f"Invalid location address {location_address!r}: {error_msg}",
149-
),
150-
)
86+
_interp.report_location_errors(node, (node.location_address,))
15187

15288
@interp.impl(move.PhysicalInitialize)
15389
def location_checker_physical(
@@ -156,22 +92,8 @@ def location_checker_physical(
15692
frame: ForwardFrame[EmptyLattice],
15793
node: move.PhysicalInitialize,
15894
):
159-
invalid_locations = list(
160-
_interp.filter_by_error(
161-
chain.from_iterable(node.location_addresses),
162-
_interp.arch_spec.validate_location,
163-
)
164-
)
165-
166-
for lane_address, error_msgs in invalid_locations:
167-
for error_msg in error_msgs:
168-
_interp.add_validation_error(
169-
node,
170-
ir.ValidationError(
171-
node,
172-
f"Invalid location address {lane_address!r}: {error_msg}",
173-
),
174-
)
95+
all_locations = tuple(chain.from_iterable(node.location_addresses))
96+
_interp.report_location_errors(node, all_locations)
17597

17698

17799
@dataclass

python/tests/arch/gemini/logical/test_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def plot():
7070

7171
def invalid_locations():
7272
arch_spec = logical.get_arch_spec()
73-
yield arch_spec, LocationAddress(16, 0), set(["Word id 16 out of range of 2"])
74-
yield arch_spec, LocationAddress(0, 32), set(["Site id 32 out of range of 10"])
73+
yield arch_spec, LocationAddress(16, 0), {"invalid location word_id=16, site_id=0"}
74+
yield arch_spec, LocationAddress(0, 32), {"invalid location word_id=0, site_id=32"}
7575

7676

7777
@pytest.mark.parametrize("arch_spec, location_address, message", invalid_locations())

0 commit comments

Comments
 (0)