Skip to content

Commit 35c0b24

Browse files
committed
perf: Port frozendict to rust
1 parent da60c64 commit 35c0b24

File tree

8 files changed

+420
-155
lines changed

8 files changed

+420
-155
lines changed

src/python/pants/backend/go/util_rules/third_party_pkg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def debug_hint(self) -> str:
115115

116116

117117
@dataclass(frozen=True)
118-
class AllThirdPartyPackages(FrozenDict[str, ThirdPartyPkgAnalysis]):
118+
class AllThirdPartyPackages:
119119
"""All the packages downloaded from a go.mod, along with a digest of the downloaded files.
120120
121121
The digest has files in the format `gopath/pkg/mod`, which is what `GoSdkProcess` sets `GOPATH`

src/python/pants/backend/url_handlers/s3/register.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ async def download_from_s3(
189189
NativeDownloadFile(
190190
url=virtual_hosted_url,
191191
expected_digest=request.expected_digest,
192-
auth_headers=http_request.headers,
192+
auth_headers=dict(http_request.headers),
193193
retry_delay_duration=global_options.file_downloads_retry_delay,
194194
max_attempts=global_options.file_downloads_max_attempts,
195195
)

src/python/pants/engine/internals/native_engine.pyi

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from __future__ import annotations
88

9-
from collections.abc import Callable, Iterable, Mapping, Sequence
9+
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
1010
from datetime import datetime
1111
from io import RawIOBase
1212
from pathlib import Path
@@ -49,6 +49,38 @@ from pants.engine.process import (
4949
class PyFailure:
5050
def get_error(self) -> Exception | None: ...
5151

52+
K = TypeVar("K")
53+
V = TypeVar("V")
54+
55+
class FrozenDict(Mapping[K, V]):
56+
"""A wrapper around a normal `dict` that removes all methods to mutate the instance and that
57+
implements __hash__.
58+
59+
This should be used instead of normal dicts when working with the engine because normal dicts
60+
are mutable and unhashable, and so are not safe to use.
61+
"""
62+
63+
@overload
64+
def __new__(cls, __items: Iterable[tuple[K, V]], **kwargs: V) -> Self: ...
65+
@overload
66+
def __new__(cls, __other: Mapping[K, V], **kwargs: V) -> Self: ...
67+
@overload
68+
def __new__(cls, **kwargs: V) -> Self: ...
69+
@classmethod
70+
def deep_freeze(cls, data: Mapping[K, V]) -> Self: ...
71+
@staticmethod
72+
def frozen(to_freeze: Mapping[K, V]) -> FrozenDict[K, V]: ...
73+
def __getitem__(self, k: K) -> V: ...
74+
def __len__(self) -> int: ...
75+
def __iter__(self) -> Iterator[K]: ...
76+
def __reversed__(self) -> Iterator[K]: ...
77+
def __eq__(self, other: Any) -> Any: ...
78+
def __lt__(self, other: Any) -> bool: ...
79+
def __or__(self, other: Any) -> FrozenDict[K, V]: ...
80+
def __ror__(self, other: Any) -> FrozenDict[K, V]: ...
81+
def __hash__(self) -> int: ...
82+
def __repr__(self) -> str: ...
83+
5284
# ------------------------------------------------------------------------------
5385
# Address
5486
# ------------------------------------------------------------------------------

src/python/pants/engine/target.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,7 +1214,7 @@ def require_unparametrized_overrides(self) -> dict[str, Mapping[str, Any]]:
12141214
class GeneratedTargets(FrozenDict[Address, Target]):
12151215
"""A mapping of the address of generated targets to the targets themselves."""
12161216

1217-
def __init__(self, generator: Target, generated_targets: Iterable[Target]) -> None:
1217+
def __new__(cls, generator: Target, generated_targets: Iterable[Target]) -> GeneratedTargets:
12181218
expected_spec_path = generator.address.spec_path
12191219
expected_tgt_name = generator.address.target_name
12201220
mapping = {}
@@ -1242,7 +1242,7 @@ def __init__(self, generator: Target, generated_targets: Iterable[Target]) -> No
12421242
"Consider using `request.generator.address.create_generated()`."
12431243
)
12441244
mapping[tgt.address] = tgt
1245-
super().__init__(mapping)
1245+
return super().__new__(cls, mapping)
12461246

12471247

12481248
@rule(polymorphic=True)

src/python/pants/util/frozendict.py

Lines changed: 33 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -3,168 +3,60 @@
33

44
from __future__ import annotations
55

6-
from collections.abc import Callable, Iterable, Iterator, Mapping
7-
from typing import Any, Self, TypeVar, cast, overload
6+
from collections.abc import Callable, ItemsView, Iterable, Mapping, ValuesView
7+
from typing import TypeVar, cast, overload
88

9+
from pants.engine.internals.native_engine import FrozenDict as FrozenDict
910
from pants.util.memo import memoized_method
10-
from pants.util.strutil import softwrap
1111

1212
K = TypeVar("K")
1313
V = TypeVar("V")
14-
15-
16-
class FrozenDict(Mapping[K, V]):
17-
"""A wrapper around a normal `dict` that removes all methods to mutate the instance and that
18-
implements __hash__.
19-
20-
This should be used instead of normal dicts when working with the engine because normal dicts
21-
are not safe to use.
22-
"""
23-
24-
@overload
25-
def __init__(self, __items: Iterable[tuple[K, V]], **kwargs: V) -> None: ...
26-
27-
@overload
28-
def __init__(self, __other: Mapping[K, V], **kwargs: V) -> None: ...
29-
30-
@overload
31-
def __init__(self, **kwargs: V) -> None: ...
32-
33-
def __init__(self, *item: Mapping[K, V] | Iterable[tuple[K, V]], **kwargs: V) -> None:
34-
"""Creates a `FrozenDict` with arguments accepted by `dict` that also must be hashable."""
35-
if len(item) > 1:
36-
raise ValueError(
37-
f"{type(self).__name__} was called with {len(item)} positional arguments but it expects one."
38-
)
39-
40-
# NB: Keep the variable name `_data` in sync with `externs/mod.rs`.
41-
self._data = dict(item[0]) if item else dict()
42-
self._data.update(**kwargs) # type: ignore[call-overload]
43-
44-
# NB: We eagerly compute the hash to validate that the values are hashable and to avoid
45-
# performing the calculation multiple times. This can be revisited if it's found to be a
46-
# performance bottleneck.
47-
self._hash = self._calculate_hash()
48-
49-
@classmethod
50-
def deep_freeze(cls, data: Mapping[K, V]) -> Self:
51-
"""Convert mutable values to their frozen counterparts.
52-
53-
Sets and lists are turned into tuples and dicts into FrozenDicts.
54-
"""
55-
56-
def _freeze(obj):
57-
if isinstance(obj, dict):
58-
return cls.deep_freeze(obj)
59-
if isinstance(obj, (list, set)):
60-
return tuple(map(_freeze, obj))
61-
return obj
62-
63-
return cls({k: _freeze(v) for k, v in data.items()})
64-
65-
@staticmethod
66-
def frozen(to_freeze: Mapping[K, V]) -> FrozenDict[K, V]:
67-
"""Returns a `FrozenDict` containing the keys and values of `to_freeze`.
68-
69-
If `to_freeze` is already a `FrozenDict`, returns the same object.
70-
"""
71-
72-
return to_freeze if isinstance(to_freeze, FrozenDict) else FrozenDict(to_freeze)
73-
74-
def __getitem__(self, k: K) -> V:
75-
return self._data[k]
76-
77-
def __len__(self) -> int:
78-
return len(self._data)
79-
80-
def __iter__(self) -> Iterator[K]:
81-
return iter(self._data)
82-
83-
def __reversed__(self) -> Iterator[K]:
84-
return reversed(tuple(self._data))
85-
86-
def __eq__(self, other: Any) -> Any:
87-
# defer to dict's __eq__
88-
return self._data == other
89-
90-
def __lt__(self, other: Any) -> bool:
91-
if not isinstance(other, FrozenDict):
92-
return NotImplemented
93-
# If sorting each of these on every __lt__ call ends up being a problem we could consider
94-
# optimising this, by, for instance, sorting on construction.
95-
return sorted(self._data.items()) < sorted(other._data.items())
96-
97-
def __or__(self, other: Any) -> FrozenDict[K, V]:
98-
if isinstance(other, FrozenDict):
99-
other = other._data
100-
elif not isinstance(other, Mapping):
101-
return NotImplemented
102-
return FrozenDict(self._data | other)
103-
104-
def __ror__(self, other: Any) -> FrozenDict[K, V]:
105-
if isinstance(other, FrozenDict):
106-
other = other._data
107-
elif not isinstance(other, Mapping):
108-
return NotImplemented
109-
return FrozenDict(other | self._data)
110-
111-
def _calculate_hash(self) -> int:
112-
try:
113-
h = 0
114-
for pair in self._data.items():
115-
# xor is commutative, i.e. we get the same hash no matter the order of items. This
116-
# "relies" on "hash" of the individual elements being unpredictable enough that such
117-
# a naive aggregation is okay. In addition, the Python hash isn't / shouldn't be
118-
# used for cryptographically sensitive purposes.
119-
h ^= hash(pair)
120-
return h
121-
except TypeError as e:
122-
raise TypeError(
123-
softwrap(
124-
f"""
125-
Even though you are using a `{type(self).__name__}`, the underlying values are
126-
not hashable. Please use hashable (and preferably immutable) types for the
127-
underlying values, e.g. use tuples instead of lists and use FrozenOrderedSet
128-
instead of set().
129-
130-
Original error message: {e}
131-
132-
Value: {self}
133-
"""
134-
)
135-
)
136-
137-
def __hash__(self) -> int:
138-
return self._hash
139-
140-
def __repr__(self) -> str:
141-
return f"{type(self).__name__}({self._data!r})"
14+
T = TypeVar("T")
14215

14316

14417
class LazyFrozenDict(FrozenDict[K, V]):
14518
"""A lazy version of `FrozenDict` where the values are not loaded until referenced."""
14619

14720
@overload
148-
def __init__(
149-
self, __items: Iterable[tuple[K, Callable[[], V]]], **kwargs: Callable[[], V]
150-
) -> None: ...
21+
def __new__(
22+
cls, __items: Iterable[tuple[K, Callable[[], V]]], **kwargs: Callable[[], V]
23+
) -> LazyFrozenDict[K, V]: ...
15124

15225
@overload
153-
def __init__(self, __other: Mapping[K, Callable[[], V]], **kwargs: Callable[[], V]) -> None: ...
26+
def __new__(
27+
cls, __other: Mapping[K, Callable[[], V]], **kwargs: Callable[[], V]
28+
) -> LazyFrozenDict[K, V]: ...
15429

15530
@overload
156-
def __init__(self, **kwargs: Callable[[], V]) -> None: ...
31+
def __new__(cls, **kwargs: Callable[[], V]) -> LazyFrozenDict[K, V]: ...
15732

158-
def __init__(
159-
self,
33+
def __new__(
34+
cls,
16035
*item: Mapping[K, Callable[[], V]] | Iterable[tuple[K, Callable[[], V]]],
16136
**kwargs: Callable[[], V],
162-
) -> None:
163-
super().__init__(*item, **kwargs) # type: ignore[arg-type]
37+
) -> LazyFrozenDict[K, V]:
38+
return super().__new__(cls, *item, **kwargs) # type: ignore[arg-type]
16439

16540
def __getitem__(self, k: K) -> V:
16641
return self._get_value(k)
16742

43+
@overload
44+
def get(self, key: K, /) -> V | None: ...
45+
@overload
46+
def get(self, key: K, /, default: T = ...) -> V | T: ...
47+
48+
def get(self, key: K, /, default: T | None = None) -> V | T | None:
49+
try:
50+
return self._get_value(key)
51+
except KeyError:
52+
return default
53+
16854
@memoized_method
16955
def _get_value(self, k: K) -> V:
170-
return cast("Callable[[], V]", self._data[k])()
56+
return cast("Callable[[], V]", super().__getitem__(k))()
57+
58+
def items(self) -> ItemsView[K, V]:
59+
return ItemsView(self)
60+
61+
def values(self) -> ValuesView[V]:
62+
return ValuesView(self)

0 commit comments

Comments
 (0)