Skip to content

Commit 42a564e

Browse files
committed
perf: Port FrozenOrderedSet to rust
1 parent 35c0b24 commit 42a564e

File tree

11 files changed

+606
-56
lines changed

11 files changed

+606
-56
lines changed

bench_frozen_ordered_set.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
"""Benchmark: Rust FrozenOrderedSet vs Python FrozenOrderedSet."""
2+
3+
import sys
4+
import timeit
5+
from collections.abc import Hashable, Iterable, Iterator
6+
from typing import AbstractSet, Any, TypeVar
7+
8+
sys.path.insert(0, "src/python")
9+
10+
from pants.engine.internals.native_engine import FrozenOrderedSet as RustFrozenOrderedSet
11+
12+
T = TypeVar("T")


class PyFrozenOrderedSet(AbstractSet[T], Hashable):
    """Pure-Python ordered, hashable set kept as the benchmark baseline.

    Elements are stored as the keys of an insertion-ordered ``dict``. The
    hash is computed on first request and cached on the instance.
    """

    def __init__(self, iterable=None):
        """Keep the distinct elements of ``iterable`` in first-seen order.

        A falsy ``iterable`` (``None`` or an empty container) gives an
        empty set.
        """
        if iterable:
            self._items = dict.fromkeys(iterable)
        else:
            self._items = {}
        # Filled in lazily by __hash__.
        self._hash = None

    def __len__(self):
        """Return the number of stored elements."""
        return len(self._items)

    def __contains__(self, key):
        """Return whether ``key`` is one of the stored elements."""
        return key in self._items

    def __iter__(self) -> Iterator:
        """Iterate over the elements in insertion order."""
        return iter(self._items)

    def __reversed__(self):
        """Iterate over a snapshot of the elements in reverse insertion order."""
        return reversed(tuple(self._items))

    def __eq__(self, other):
        """Order-sensitive equality against another instance of this class.

        Returns ``NotImplemented`` for any other type.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        if len(self._items) != len(other._items):
            return False
        return all(mine == theirs for mine, theirs in zip(self._items, other._items))

    def __hash__(self):
        """Return the XOR of the element hashes, computing it only once."""
        if self._hash is not None:
            return self._hash
        self._hash = 0
        for element in self._items:
            self._hash ^= hash(element)
        return self._hash

    def __repr__(self):
        """Return ``PyFrozenOrderedSet([...])`` listing elements in order."""
        return f"PyFrozenOrderedSet({list(self)!r})"

    def __bool__(self):
        """Return ``True`` when at least one element is stored."""
        return bool(self._items)

    def union(self, other):
        """Return a new set: this set's elements, then the unseen ones of ``other``."""
        additions = [candidate for candidate in other if candidate not in self._items]
        return self.__class__(list(self) + additions)

    def intersection(self, other):
        """Return a new set of this set's elements that also occur in ``other``."""
        lookup = set(other)
        return self.__class__(element for element in self if element in lookup)

    def difference(self, other):
        """Return a new set of this set's elements that do not occur in ``other``."""
        lookup = set(other)
        return self.__class__(element for element in self if element not in lookup)

    def issubset(self, other):
        """Return whether every element of this set is contained in ``other``.

        ``other`` must support ``len()``; a larger ``self`` short-circuits
        to ``False`` before any membership test runs.
        """
        return len(self) <= len(other) and all(element in other for element in self)
69+
70+
71+
# Number of untimed executions performed before sampling, so one-off costs
# (lazy initialisation, cache warm-up) do not leak into the measurement.
WARMUP = 1000


def measure(stmt, number, globs, repeat=3):
    """Return the best observed cost of one execution of ``stmt`` in microseconds.

    The statement is first executed ``WARMUP`` times without being timed.
    It is then timed ``repeat`` times, each sample running it ``number``
    times, and the fastest sample is reported. Slower samples mostly
    reflect interference from other processes rather than the code under
    test, so the minimum is the most stable figure to compare.

    Args:
        stmt: Source of the statement to time.
        number: Executions per timed sample; must be at least 1.
        globs: Namespace in which ``stmt`` is executed.
        repeat: Number of timed samples to take; must be at least 1.

    Returns:
        Microseconds per single execution of ``stmt``, as a float.

    Raises:
        ValueError: If ``number`` or ``repeat`` is smaller than 1.
    """
    if number < 1:
        raise ValueError(f"number must be at least 1, got {number!r}")
    if repeat < 1:
        raise ValueError(f"repeat must be at least 1, got {repeat!r}")

    timeit.timeit(stmt, number=WARMUP, globals=globs)
    best = min(timeit.repeat(stmt, repeat=repeat, number=number, globals=globs))
    # Seconds for `number` runs -> microseconds for a single run.
    return best / number * 1_000_000
77+
78+
79+
def _construction_globs(data, py, rs, **_):
    """Namespaces for timing construction of each implementation from ``data``."""
    return [
        {"Cls": PyFrozenOrderedSet, "data": data},
        {"Cls": RustFrozenOrderedSet, "data": data},
    ]


def _single_set_globs(py, rs, **_):
    """Namespaces exposing only the set under test as ``fd``."""
    return [{"fd": py}, {"fd": rs}]


def _contains_hit_globs(py, rs, mid, **_):
    """Namespaces for a membership test with the key ``mid``."""
    return [{"fd": py, "k": mid}, {"fd": rs, "k": mid}]


def _contains_miss_globs(py, rs, **_):
    """Namespaces for a membership test with the fixed key ``"MISSING"``."""
    return [{"fd": py, "k": "MISSING"}, {"fd": rs, "k": "MISSING"}]


def _equality_globs(py, rs, py2, rs2, **_):
    """Namespaces pairing each set with its same-implementation counterpart."""
    return [{"fd": py, "fd2": py2}, {"fd": rs, "fd2": rs2}]


def _binary_op_globs(py, rs, py_other, rs_other, **_):
    """Namespaces for set-algebra methods taking a second set as ``other``."""
    return [{"fd": py, "other": py_other}, {"fd": rs, "other": rs_other}]


def _subset_globs(py, rs, py_small, rs_small, **_):
    """Namespaces for checking ``small`` against the full set ``fd``."""
    return [{"small": py_small, "fd": py}, {"small": rs_small, "fd": rs}]


def _dict_key_globs(py, rs, **_):
    """Namespaces for looking the set up as a key of a one-entry dict ``d``."""
    return [{"fd": py, "d": {py: 1}}, {"fd": rs, "d": {rs: 1}}]


# Each entry is (display name, statement to time, namespace factory).
# A factory receives the per-dataset context as keyword arguments and
# returns [python_namespace, rust_namespace] for the statement.
BENCHMARKS = [
    ("Construction", "Cls(data)", _construction_globs),
    ("hash()", "hash(fd)", _single_set_globs),
    ("__contains__", "k in fd", _contains_hit_globs),
    ("__contains__ miss", "k in fd", _contains_miss_globs),
    ("__eq__", "fd == fd2", _equality_globs),
    ("iteration", "list(fd)", _single_set_globs),
    ("union", "fd.union(other)", _binary_op_globs),
    ("intersection", "fd.intersection(other)", _binary_op_globs),
    ("difference", "fd.difference(other)", _binary_op_globs),
    ("issubset", "small.issubset(fd)", _subset_globs),
    ("dict key", "d[fd]", _dict_key_globs),
]
125+
126+
# Input sizes exercised by the benchmark.
SMALL = list(range(5))
MEDIUM = list(range(20))
LARGE = list(range(200))

DATASETS = [
    ("small (5)", SMALL),
    ("medium (20)", MEDIUM),
    ("large (200)", LARGE),
]

# bench name -> dataset name -> (python µs, rust µs)
all_results: dict[str, dict[str, tuple[float, float]]] = {}

for ds_name, data in DATASETS:
    print(f"\n{'=' * 60}")
    print(f" Dataset: {ds_name}")
    print(f"{'=' * 60}")

    # Two equal sets per implementation, so __eq__ compares distinct objects.
    py, py2 = PyFrozenOrderedSet(data), PyFrozenOrderedSet(data)
    rs, rs2 = RustFrozenOrderedSet(data), RustFrozenOrderedSet(data)

    # The "other" operand shares the first half of the data and adds
    # the same number of values starting at 1000.
    half = data[: len(data) // 2]
    py_other = PyFrozenOrderedSet([*half, *range(1000, 1000 + len(half))])
    rs_other = RustFrozenOrderedSet([*half, *range(1000, 1000 + len(half))])

    py_small = PyFrozenOrderedSet(data[:3])
    rs_small = RustFrozenOrderedSet(data[:3])

    # Warm up lazy hashes
    for obj in (py, py2, py_other, py_small):
        hash(obj)

    # Fewer timed iterations for the largest dataset.
    if len(data) <= 20:
        n = 500_000
    else:
        n = 50_000
    mid = data[len(data) // 2]

    ctx = {
        "data": data,
        "py": py,
        "rs": rs,
        "py2": py2,
        "rs2": rs2,
        "mid": mid,
        "py_other": py_other,
        "rs_other": rs_other,
        "py_small": py_small,
        "rs_small": rs_small,
    }

    for bench_name, stmt, make_globs in BENCHMARKS:
        py_globs, rs_globs = make_globs(**ctx)
        py_us = measure(stmt, n, py_globs)
        rs_us = measure(stmt, n, rs_globs)
        print(f" {bench_name:.<20s} Python {py_us:8.3f} µs Rust {rs_us:8.3f} µs ({py_us / rs_us:.1f}x)")
        per_dataset = all_results.setdefault(bench_name, {})
        per_dataset[ds_name] = (py_us, rs_us)
163+
164+
# Summary table: one row per benchmark, one speedup column per dataset.
ds_names = [entry[0] for entry in DATASETS]
header = "".join([f" {'Operation':<20s}", *(f" | {name:>12s}" for name in ds_names)])
sep = f" {'-' * 20}" + f"-+-{'-' * 12}" * len(ds_names)

print(f"\n{'=' * 60}")
print(" Summary (Python / Rust speedup)")
print(f"{'=' * 60}")
print(header)
print(sep)
for bench_name, _stmt, _make_globs in BENCHMARKS:
    cells = []
    for ds_name in ds_names:
        py_us, rs_us = all_results[bench_name][ds_name]
        # A ratio above 1 means the Rust timing was lower than the Python one.
        cells.append(f" | {py_us / rs_us:11.1f}x")
    print(f" {bench_name:<20s}" + "".join(cells))

src/python/pants/backend/python/util_rules/interpreter_constraints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,14 @@ def for_fixed_python_version(
8989
) -> InterpreterConstraints:
9090
return cls([f"{interpreter_type}=={python_version_str}"])
9191

92-
def __init__(self, constraints: Iterable[str | Requirement] = ()) -> None:
92+
def __new__(cls, constraints: Iterable[str | Requirement] = ()) -> InterpreterConstraints:
9393
# #12578 `parse_constraint` will sort the requirement's component constraints into a stable form.
9494
# We need to sort the component constraints for each requirement _before_ sorting the entire list
9595
# for the ordering to be correct.
9696
parsed_constraints = (
9797
i if isinstance(i, Requirement) else parse_constraint(i) for i in constraints
9898
)
99-
super().__init__(sorted(parsed_constraints, key=lambda c: str(c)))
99+
return super().__new__(cls, sorted(parsed_constraints, key=lambda c: str(c)))
100100

101101
def __str__(self) -> str:
102102
return " OR ".join(str(constraint) for constraint in self)

src/python/pants/backend/python/util_rules/pex.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ class CompletePlatforms(DeduplicatedCollection[str]):
133133
sort_input = True
134134

135135
def __init__(self, iterable: Iterable[str] = (), *, digest: Digest = EMPTY_DIGEST):
136-
super().__init__(iterable)
137136
self._digest = digest
138137

139138
@classmethod

src/python/pants/engine/collection.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,11 @@ class Examples(DeduplicatedCollection[Example]):
7979

8080
sort_input: ClassVar[bool] = False
8181

82-
def __init__(self, iterable: Iterable[T] = ()) -> None:
83-
super().__init__(
84-
iterable if not self.sort_input else sorted(iterable) # type: ignore[type-var]
82+
def __new__(cls, iterable: Iterable[T] = (), **_kwargs: object) -> DeduplicatedCollection[T]:
83+
return super().__new__(
84+
cls,
85+
iterable if not cls.sort_input else sorted(iterable), # type: ignore[type-var]
8586
)
8687

8788
def __repr__(self) -> str:
88-
return f"{self.__class__.__name__}({list(self._items)})"
89+
return f"{self.__class__.__name__}({list(self)})"

src/python/pants/engine/internals/native_engine.pyi

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
from __future__ import annotations
88

9-
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
9+
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence
1010
from datetime import datetime
1111
from io import RawIOBase
1212
from pathlib import Path
13-
from typing import Any, ClassVar, Protocol, Self, TextIO, TypeVar, overload
13+
from typing import AbstractSet, Any, ClassVar, Protocol, Self, TextIO, TypeVar, overload
1414

1515
from pants.engine.fs import (
1616
CreateDigest,
@@ -81,6 +81,34 @@ class FrozenDict(Mapping[K, V]):
8181
def __hash__(self) -> int: ...
8282
def __repr__(self) -> str: ...
8383

84+
T_co = TypeVar("T_co", covariant=True)
85+
86+
class FrozenOrderedSet(AbstractSet[T_co], Hashable):
87+
"""A frozen (i.e. immutable) ordered set backed by Rust.
88+
89+
This is safe to use with the V2 engine.
90+
"""
91+
92+
def __new__(cls, iterable: Iterable[T_co] | None = None) -> Self: ...
93+
def __len__(self) -> int: ...
94+
def __contains__(self, key: Any) -> bool: ...
95+
def __iter__(self) -> Iterator[T_co]: ...
96+
def __reversed__(self) -> Iterator[T_co]: ...
97+
def __hash__(self) -> int: ...
98+
def __eq__(self, other: Any) -> bool: ...
99+
def __or__(self, other: Iterable[T_co]) -> FrozenOrderedSet[T_co]: ... # type: ignore[override] # widens from AbstractSet
100+
def __and__(self, other: Iterable[T_co]) -> FrozenOrderedSet[T_co]: ...
101+
def __sub__(self, other: Iterable[T_co]) -> FrozenOrderedSet[T_co]: ...
102+
def __xor__(self, other: Iterable[T_co]) -> FrozenOrderedSet[T_co]: ... # type: ignore[override] # widens from AbstractSet
103+
def __bool__(self) -> bool: ...
104+
def __repr__(self) -> str: ...
105+
def union(self, *others: Iterable[T_co]) -> FrozenOrderedSet[T_co]: ...
106+
def intersection(self, *others: Iterable[T_co]) -> FrozenOrderedSet[T_co]: ...
107+
def difference(self, *others: Iterable[T_co]) -> FrozenOrderedSet[T_co]: ...
108+
def symmetric_difference(self, other: Iterable[T_co]) -> FrozenOrderedSet[T_co]: ...
109+
def issubset(self, other: Iterable[T_co]) -> bool: ...
110+
def issuperset(self, other: Iterable[T_co]) -> bool: ...
111+
84112
# ------------------------------------------------------------------------------
85113
# Address
86114
# ------------------------------------------------------------------------------

src/python/pants/util/ordered_set.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from __future__ import annotations
1616

1717
import itertools
18-
from collections.abc import Hashable, Iterable, Iterator, MutableSet
18+
from collections.abc import Iterable, Iterator, MutableSet
1919
from typing import AbstractSet, Any, TypeVar, cast
2020

21+
from pants.engine.internals.native_engine import FrozenOrderedSet as FrozenOrderedSet # noqa: F401
22+
2123
T = TypeVar("T")
2224
T_co = TypeVar("T_co", covariant=True)
2325
_TAbstractOrderedSet = TypeVar("_TAbstractOrderedSet", bound="_AbstractOrderedSet")
@@ -195,21 +197,3 @@ def symmetric_difference_update(self, other: Iterable[T]) -> None:
195197
self._items = {item: None for item in self._items.keys() if item not in items_to_remove}
196198
for item in items_to_add:
197199
self._items[item] = None
198-
199-
200-
class FrozenOrderedSet(_AbstractOrderedSet[T_co], Hashable): # type: ignore[type-var]
201-
"""A frozen (i.e. immutable) set that retains its order.
202-
203-
This is safe to use with the V2 engine.
204-
"""
205-
206-
def __init__(self, iterable: Iterable[T_co] | None = None) -> None:
207-
super().__init__(iterable)
208-
self.__hash: int | None = None
209-
210-
def __hash__(self) -> int:
211-
if self.__hash is None:
212-
self.__hash = 0
213-
for item in self._items.keys():
214-
self.__hash ^= hash(item)
215-
return self.__hash

0 commit comments

Comments
 (0)