|
| 1 | +"""Benchmark: Rust FrozenOrderedSet vs Python FrozenOrderedSet.""" |
| 2 | + |
| 3 | +import sys |
| 4 | +import timeit |
| 5 | +from collections.abc import Hashable, Iterable, Iterator |
| 6 | +from typing import AbstractSet, Any, TypeVar |
| 7 | + |
| 8 | +sys.path.insert(0, "src/python") |
| 9 | + |
| 10 | +from pants.engine.internals.native_engine import FrozenOrderedSet as RustFrozenOrderedSet |
| 11 | + |
| 12 | +T = TypeVar("T") |
| 13 | + |
| 14 | + |
| 15 | +class PyFrozenOrderedSet(AbstractSet[T], Hashable): |
| 16 | + """The old pure-Python FrozenOrderedSet (pre-port).""" |
| 17 | + |
| 18 | + def __init__(self, iterable=None): |
| 19 | + self._items = dict.fromkeys(iterable) if iterable else {} |
| 20 | + self._hash = None |
| 21 | + |
| 22 | + def __len__(self): |
| 23 | + return len(self._items) |
| 24 | + |
| 25 | + def __contains__(self, key): |
| 26 | + return key in self._items |
| 27 | + |
| 28 | + def __iter__(self) -> Iterator: |
| 29 | + return iter(self._items) |
| 30 | + |
| 31 | + def __reversed__(self): |
| 32 | + return reversed(tuple(self._items.keys())) |
| 33 | + |
| 34 | + def __eq__(self, other): |
| 35 | + if not isinstance(other, self.__class__): |
| 36 | + return NotImplemented |
| 37 | + return len(self._items) == len(other._items) and all( |
| 38 | + x == y for x, y in zip(self._items, other._items) |
| 39 | + ) |
| 40 | + |
| 41 | + def __hash__(self): |
| 42 | + if self._hash is None: |
| 43 | + self._hash = 0 |
| 44 | + for item in self._items.keys(): |
| 45 | + self._hash ^= hash(item) |
| 46 | + return self._hash |
| 47 | + |
| 48 | + def __repr__(self): |
| 49 | + return f"PyFrozenOrderedSet({list(self)!r})" |
| 50 | + |
| 51 | + def __bool__(self): |
| 52 | + return bool(self._items) |
| 53 | + |
| 54 | + def union(self, other): |
| 55 | + return self.__class__(list(self) + [x for x in other if x not in self._items]) |
| 56 | + |
| 57 | + def intersection(self, other): |
| 58 | + s = set(other) |
| 59 | + return self.__class__(x for x in self if x in s) |
| 60 | + |
| 61 | + def difference(self, other): |
| 62 | + s = set(other) |
| 63 | + return self.__class__(x for x in self if x not in s) |
| 64 | + |
| 65 | + def issubset(self, other): |
| 66 | + if len(self) > len(other): |
| 67 | + return False |
| 68 | + return all(item in other for item in self) |
| 69 | + |
| 70 | + |
| 71 | +WARMUP = 1000 |
| 72 | + |
| 73 | +def measure(stmt, number, globs): |
| 74 | + timeit.timeit(stmt, number=WARMUP, globals=globs) |
| 75 | + t = timeit.timeit(stmt, number=number, globals=globs) |
| 76 | + return t / number * 1_000_000 |
| 77 | + |
| 78 | + |
| 79 | +BENCHMARKS = [ |
| 80 | + ("Construction", "Cls(data)", lambda data, py, rs, **_: [ |
| 81 | + {"Cls": PyFrozenOrderedSet, "data": data}, |
| 82 | + {"Cls": RustFrozenOrderedSet, "data": data}, |
| 83 | + ]), |
| 84 | + ("hash()", "hash(fd)", lambda py, rs, **_: [ |
| 85 | + {"fd": py}, |
| 86 | + {"fd": rs}, |
| 87 | + ]), |
| 88 | + ("__contains__", "k in fd", lambda py, rs, mid, **_: [ |
| 89 | + {"fd": py, "k": mid}, |
| 90 | + {"fd": rs, "k": mid}, |
| 91 | + ]), |
| 92 | + ("__contains__ miss", "k in fd", lambda py, rs, **_: [ |
| 93 | + {"fd": py, "k": "MISSING"}, |
| 94 | + {"fd": rs, "k": "MISSING"}, |
| 95 | + ]), |
| 96 | + ("__eq__", "fd == fd2", lambda py, rs, py2, rs2, **_: [ |
| 97 | + {"fd": py, "fd2": py2}, |
| 98 | + {"fd": rs, "fd2": rs2}, |
| 99 | + ]), |
| 100 | + ("iteration", "list(fd)", lambda py, rs, **_: [ |
| 101 | + {"fd": py}, |
| 102 | + {"fd": rs}, |
| 103 | + ]), |
| 104 | + ("union", "fd.union(other)", lambda py, rs, py_other, rs_other, **_: [ |
| 105 | + {"fd": py, "other": py_other}, |
| 106 | + {"fd": rs, "other": rs_other}, |
| 107 | + ]), |
| 108 | + ("intersection", "fd.intersection(other)", lambda py, rs, py_other, rs_other, **_: [ |
| 109 | + {"fd": py, "other": py_other}, |
| 110 | + {"fd": rs, "other": rs_other}, |
| 111 | + ]), |
| 112 | + ("difference", "fd.difference(other)", lambda py, rs, py_other, rs_other, **_: [ |
| 113 | + {"fd": py, "other": py_other}, |
| 114 | + {"fd": rs, "other": rs_other}, |
| 115 | + ]), |
| 116 | + ("issubset", "small.issubset(fd)", lambda py, rs, py_small, rs_small, **_: [ |
| 117 | + {"small": py_small, "fd": py}, |
| 118 | + {"small": rs_small, "fd": rs}, |
| 119 | + ]), |
| 120 | + ("dict key", "d[fd]", lambda py, rs, **_: [ |
| 121 | + {"fd": py, "d": {py: 1}}, |
| 122 | + {"fd": rs, "d": {rs: 1}}, |
| 123 | + ]), |
| 124 | +] |
| 125 | + |
| 126 | +SMALL = list(range(5)) |
| 127 | +MEDIUM = list(range(20)) |
| 128 | +LARGE = list(range(200)) |
| 129 | + |
| 130 | +DATASETS = [("small (5)", SMALL), ("medium (20)", MEDIUM), ("large (200)", LARGE)] |
| 131 | + |
| 132 | +all_results: dict[str, dict[str, tuple[float, float]]] = {} |
| 133 | + |
| 134 | +for ds_name, data in DATASETS: |
| 135 | + print(f"\n{'=' * 60}") |
| 136 | + print(f" Dataset: {ds_name}") |
| 137 | + print(f"{'=' * 60}") |
| 138 | + |
| 139 | + py = PyFrozenOrderedSet(data) |
| 140 | + rs = RustFrozenOrderedSet(data) |
| 141 | + py2 = PyFrozenOrderedSet(data) |
| 142 | + rs2 = RustFrozenOrderedSet(data) |
| 143 | + half = data[:len(data) // 2] |
| 144 | + py_other = PyFrozenOrderedSet(half + list(range(1000, 1000 + len(half)))) |
| 145 | + rs_other = RustFrozenOrderedSet(half + list(range(1000, 1000 + len(half)))) |
| 146 | + py_small = PyFrozenOrderedSet(data[:3]) |
| 147 | + rs_small = RustFrozenOrderedSet(data[:3]) |
| 148 | + # Warm up lazy hashes |
| 149 | + for obj in (py, py2, py_other, py_small): |
| 150 | + hash(obj) |
| 151 | + n = 500_000 if len(data) <= 20 else 50_000 |
| 152 | + mid = data[len(data) // 2] |
| 153 | + |
| 154 | + ctx = dict(data=data, py=py, rs=rs, py2=py2, rs2=rs2, mid=mid, |
| 155 | + py_other=py_other, rs_other=rs_other, py_small=py_small, rs_small=rs_small) |
| 156 | + |
| 157 | + for bench_name, stmt, make_globs in BENCHMARKS: |
| 158 | + py_globs, rs_globs = make_globs(**ctx) |
| 159 | + py_us = measure(stmt, n, py_globs) |
| 160 | + rs_us = measure(stmt, n, rs_globs) |
| 161 | + print(f" {bench_name:.<20s} Python {py_us:8.3f} µs Rust {rs_us:8.3f} µs ({py_us / rs_us:.1f}x)") |
| 162 | + all_results.setdefault(bench_name, {})[ds_name] = (py_us, rs_us) |
| 163 | + |
| 164 | +ds_names = [name for name, _ in DATASETS] |
| 165 | +header = f" {'Operation':<20s}" + "".join(f" | {name:>12s}" for name in ds_names) |
| 166 | +sep = f" {'-'*20}" + "".join(f"-+-{'-'*12}" for _ in ds_names) |
| 167 | + |
| 168 | +print(f"\n{'=' * 60}") |
| 169 | +print(" Summary (Python / Rust speedup)") |
| 170 | +print(f"{'=' * 60}") |
| 171 | +print(header) |
| 172 | +print(sep) |
| 173 | +for bench_name, _, _ in BENCHMARKS: |
| 174 | + row = f" {bench_name:<20s}" |
| 175 | + for ds_name in ds_names: |
| 176 | + py_us, rs_us = all_results[bench_name][ds_name] |
| 177 | + ratio = py_us / rs_us |
| 178 | + row += f" | {ratio:11.1f}x" |
| 179 | + print(row) |
0 commit comments