Skip to content

Commit 88de572

Browse files
authored
[ty] Garbage-collect reachability constraints (astral-sh#19414)
This is a follow-on to astral-sh#19410 that further reduces the memory usage of our reachability constraints. When finishing the building of a use-def map, we walk through all of the "final" states and mark only those reachability constraints as "used". We then throw away the interior TDD nodes of any reachability constraints that weren't marked as used. (This helps because we build up quite a few intermediate TDD nodes when constructing complex reachability constraints. These nodes can never be accessed if they were _only_ used as an intermediate TDD node. The marking step ensures that we keep any nodes that ended up being referred to in some accessible use-def map state.)
1 parent b8dec79 commit 88de572

File tree

9 files changed

+252
-10
lines changed

9 files changed

+252
-10
lines changed

Cargo.lock

Lines changed: 40 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ assert_fs = { version = "1.1.0" }
5757
argfile = { version = "0.2.0" }
5858
bincode = { version = "2.0.0" }
5959
bitflags = { version = "2.5.0" }
60+
bitvec = { version = "1.0.1", default-features = false, features = [
61+
"alloc",
62+
] }
6063
bstr = { version = "1.9.1" }
6164
cachedir = { version = "0.3.1" }
6265
camino = { version = "1.1.7" }

crates/ty_python_semantic/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ ty_static = { workspace = true }
2626

2727
anyhow = { workspace = true }
2828
bitflags = { workspace = true }
29+
bitvec = { workspace = true }
2930
camino = { workspace = true }
3031
colored = { workspace = true }
3132
compact_str = { workspace = true }

crates/ty_python_semantic/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ mod node_key;
3434
pub(crate) mod place;
3535
mod program;
3636
mod python_platform;
37+
mod rank;
3738
pub mod semantic_index;
3839
mod semantic_model;
3940
pub(crate) mod site_packages;
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//! A boxed bit slice that supports a constant-time `rank` operation.
2+
3+
use bitvec::prelude::{BitBox, Msb0};
4+
use get_size2::GetSize;
5+
6+
/// A boxed bit slice that supports a constant-time `rank` operation.
7+
///
8+
/// This can be used to "shrink" a large vector, where you only need to keep certain elements, and
9+
/// you want to continue to use the index in the large vector to identify each element.
10+
///
11+
/// First you create a new smaller vector, keeping only the elements of the large vector that you
12+
/// care about. Now you need a way to translate an index into the large vector (which no longer
13+
/// exists) into the corresponding index into the smaller vector. To do that, you create a bit
14+
/// slice, containing a bit for every element of the original large vector. Each bit in the bit
15+
/// slice indicates whether that element of the large vector was kept in the smaller vector. And
16+
/// the `rank` of the bit gives us the index of the element in the smaller vector.
17+
///
18+
/// However, the naive implementation of `rank` is O(n) in the size of the bit slice. To address
19+
/// that, we use a standard trick: we divide the bit slice into fixed-size chunks, and when
20+
/// constructing the bit slice, precalculate the rank of the first bit in each chunk. Then, to
21+
/// calculate the rank of an arbitrary bit, we first grab the precalculated rank of the chunk that
22+
/// bit belongs to, and add the rank of the bit within its (fixed-sized) chunk.
23+
///
24+
/// In total this costs 1.5 bits per large vector element on 64-bit platforms (one bit in the bit
25+
/// slice plus one `u32` rank per 64-bit chunk), and 2 bits on other platforms (32-bit chunks).
26+
#[derive(Clone, Debug, Eq, PartialEq, GetSize)]
27+
pub(crate) struct RankBitBox {
28+
/// One bit per element of the original large vector; set if that element was kept.
#[get_size(size_fn = bit_box_size)]
29+
bits: BitBox<Chunk, Msb0>,
30+
/// For each chunk of `bits`, the number of set bits in all chunks that precede it.
chunk_ranks: Box<[u32]>,
31+
}
32+
33+
/// Size function used by the `GetSize` derive for `RankBitBox::bits`: reports the heap size of
/// the bit box as the heap size of its raw backing slice.
fn bit_box_size(bits: &BitBox<Chunk, Msb0>) -> usize {
34+
bits.as_raw_slice().get_heap_size()
35+
}
36+
37+
// The storage element backing the bit slice: `u64` on 64-bit targets, `u32` everywhere else.
// bitvec does not support `u64` as a Store type on 32-bit platforms
38+
#[cfg(target_pointer_width = "64")]
39+
type Chunk = u64;
40+
#[cfg(not(target_pointer_width = "64"))]
41+
type Chunk = u32;
42+
43+
/// The number of bits in one [`Chunk`].
const CHUNK_SIZE: usize = Chunk::BITS as usize;
44+
45+
impl RankBitBox {
46+
/// Collects `iter` into a bit slice and precomputes, for every chunk of that slice, the number
/// of set bits in all earlier chunks.
pub(crate) fn from_bits(iter: impl Iterator<Item = bool>) -> Self {
47+
let bits: BitBox<Chunk, Msb0> = iter.collect();
48+
let chunk_ranks = bits
49+
.as_raw_slice()
50+
.iter()
51+
.scan(0u32, |rank, chunk| {
52+
// Emit the running total *before* adding this chunk's bits, so that each entry counts
// only the chunks that come before it.
let result = *rank;
53+
*rank += chunk.count_ones();
54+
Some(result)
55+
})
56+
.collect();
57+
Self { bits, chunk_ranks }
58+
}
59+
60+
/// Returns the bit at `index`, or `None` if `index` is out of bounds for the bit slice.
#[inline]
61+
pub(crate) fn get_bit(&self, index: usize) -> Option<bool> {
62+
self.bits.get(index).map(|bit| *bit)
63+
}
64+
65+
/// Returns the number of bits _before_ (and not including) the given index that are set.
///
/// # Panics
///
/// Panics if `index / CHUNK_SIZE` is not a valid chunk index, i.e. if `index` lies past the
/// last chunk of the bit slice.
66+
#[inline]
67+
pub(crate) fn rank(&self, index: usize) -> u32 {
68+
let chunk_index = index / CHUNK_SIZE;
69+
let index_within_chunk = index % CHUNK_SIZE;
70+
let chunk_rank = self.chunk_ranks[chunk_index];
71+
// A chunk-aligned index needs no bits from its own chunk. Returning early also keeps the
// shift below strictly smaller than `CHUNK_SIZE`, which would otherwise overflow.
if index_within_chunk == 0 {
72+
return chunk_rank;
73+
}
74+
75+
// To calculate the rank within the bit's chunk, we zero out the requested bit and every
76+
// bit to the right, then count the number of 1s remaining (i.e., to the left of the
77+
// requested bit). The bit slice uses `Msb0` ordering, so "to the left" means the
// higher-order bits of the chunk.
78+
let chunk = self.bits.as_raw_slice()[chunk_index];
79+
let chunk_mask = Chunk::MAX << (CHUNK_SIZE - index_within_chunk);
80+
let rank_within_chunk = (chunk & chunk_mask).count_ones();
81+
chunk_rank + rank_within_chunk
82+
}
83+
}

crates/ty_python_semantic/src/semantic_index/builder.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
10211021

10221022
assert_eq!(&self.current_assignments, &[]);
10231023

1024+
for scope in &self.scopes {
1025+
if let Some(parent) = scope.parent() {
1026+
self.use_def_maps[parent]
1027+
.reachability_constraints
1028+
.mark_used(scope.reachability());
1029+
}
1030+
}
1031+
10241032
let mut place_tables: IndexVec<_, _> = self
10251033
.place_tables
10261034
.into_iter()

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ use rustc_hash::FxHashMap;
201201
use crate::Db;
202202
use crate::dunder_all::dunder_all_names;
203203
use crate::place::{RequiresExplicitReExport, imported_symbol};
204+
use crate::rank::RankBitBox;
204205
use crate::semantic_index::expression::Expression;
205206
use crate::semantic_index::place_table;
206207
use crate::semantic_index::predicate::{
@@ -283,6 +284,10 @@ impl ScopedReachabilityConstraintId {
283284
/// Returns `true` if this id is at or above `SMALLEST_TERMINAL`, i.e. it names a terminal
/// rather than an interior node.
fn is_terminal(self) -> bool {
284285
self.0 >= SMALLEST_TERMINAL.0
285286
}
287+
288+
/// Returns the raw `u32` value wrapped by this id.
fn as_u32(self) -> u32 {
289+
self.0
290+
}
286291
}
287292

288293
impl Idx for ScopedReachabilityConstraintId {
@@ -309,12 +314,18 @@ const SMALLEST_TERMINAL: ScopedReachabilityConstraintId = ALWAYS_FALSE;
309314
/// A collection of reachability constraints for a given scope.
310315
#[derive(Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
311316
pub(crate) struct ReachabilityConstraints {
312-
interiors: IndexVec<ScopedReachabilityConstraintId, InteriorNode>,
317+
/// The interior TDD nodes that were marked as used when being built.
318+
used_interiors: Box<[InteriorNode]>,
319+
/// A bit vector indicating which interior TDD nodes were marked as used. This is indexed by
320+
/// the node's [`ScopedReachabilityConstraintId`]. The rank of the corresponding bit gives the
321+
/// index of that node in the `used_interiors` vector.
322+
used_indices: RankBitBox,
313323
}
314324

315325
#[derive(Debug, Default, PartialEq, Eq)]
316326
pub(crate) struct ReachabilityConstraintsBuilder {
317327
interiors: IndexVec<ScopedReachabilityConstraintId, InteriorNode>,
328+
interior_used: IndexVec<ScopedReachabilityConstraintId, bool>,
318329
interior_cache: FxHashMap<InteriorNode, ScopedReachabilityConstraintId>,
319330
not_cache: FxHashMap<ScopedReachabilityConstraintId, ScopedReachabilityConstraintId>,
320331
and_cache: FxHashMap<
@@ -334,11 +345,28 @@ pub(crate) struct ReachabilityConstraintsBuilder {
334345
}
335346

336347
impl ReachabilityConstraintsBuilder {
337-
pub(crate) fn build(mut self) -> ReachabilityConstraints {
338-
self.interiors.shrink_to_fit();
339-
348+
pub(crate) fn build(self) -> ReachabilityConstraints {
349+
let used_indices = RankBitBox::from_bits(self.interior_used.iter().copied());
350+
let used_interiors = (self.interiors.into_iter())
351+
.zip(self.interior_used)
352+
.filter_map(|(interior, used)| used.then_some(interior))
353+
.collect();
340354
ReachabilityConstraints {
341-
interiors: self.interiors,
355+
used_interiors,
356+
used_indices,
357+
}
358+
}
359+
360+
/// Marks that a particular TDD node is used. This lets us throw away interior nodes that were
361+
/// only calculated for intermediate values, and which don't need to be included in the final
362+
/// built result.
///
/// Marking is transitive: the `if_true`, `if_ambiguous` and `if_false` children of an interior
/// node are marked as well. Terminal nodes are ignored.
363+
pub(crate) fn mark_used(&mut self, node: ScopedReachabilityConstraintId) {
364+
// Skipping nodes that are already marked means a shared subgraph is only walked once.
if !node.is_terminal() && !self.interior_used[node] {
365+
self.interior_used[node] = true;
366+
// Copy the node out so that `self` is free to be borrowed mutably by the recursive calls.
let node = self.interiors[node];
367+
self.mark_used(node.if_true);
368+
self.mark_used(node.if_ambiguous);
369+
self.mark_used(node.if_false);
342370
}
343371
}
344372

@@ -370,10 +398,10 @@ impl ReachabilityConstraintsBuilder {
370398
return node.if_true;
371399
}
372400

373-
*self
374-
.interior_cache
375-
.entry(node)
376-
.or_insert_with(|| self.interiors.push(node))
401+
*self.interior_cache.entry(node).or_insert_with(|| {
402+
self.interior_used.push(false);
403+
self.interiors.push(node)
404+
})
377405
}
378406

379407
/// Adds a new reachability constraint that checks a single [`Predicate`].
@@ -581,7 +609,21 @@ impl ReachabilityConstraints {
581609
ALWAYS_TRUE => return Truthiness::AlwaysTrue,
582610
AMBIGUOUS => return Truthiness::Ambiguous,
583611
ALWAYS_FALSE => return Truthiness::AlwaysFalse,
584-
_ => self.interiors[id],
612+
_ => {
613+
// `id` gives us the index of this node in the IndexVec that we used when
614+
// constructing this TDD. When finalizing the builder, we threw away any
615+
// interior nodes that weren't marked as used. The `used_indices` bit vector
616+
// lets us verify that this node was marked as used, and the rank of that bit
617+
// in the bit vector tells us where this node lives in the "condensed"
618+
// `used_interiors` vector.
619+
let raw_index = id.as_u32() as usize;
620+
debug_assert!(
621+
self.used_indices.get_bit(raw_index).unwrap_or(false),
622+
"all used reachability constraints should have been marked as used",
623+
);
624+
let index = self.used_indices.rank(raw_index) as usize;
625+
self.used_interiors[index]
626+
}
585627
};
586628
let predicate = &predicates[node.atom];
587629
match Self::analyze_single(db, predicate) {

crates/ty_python_semantic/src/semantic_index/use_def.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,7 +1118,41 @@ impl<'db> UseDefMapBuilder<'db> {
11181118
.add_or_constraint(self.reachability, snapshot.reachability);
11191119
}
11201120

1121+
/// Marks the reachability constraints that are still referenced by this builder's state as
/// used, so that the constraint builder can tell them apart from intermediate nodes.
///
/// NOTE(review): the `finish(&mut self.reachability_constraints)` helpers called below are not
/// defined in this hunk; they are presumably what marks the constraints held by each value.
fn mark_reachability_constraints(&mut self) {
1122+
// We only walk the fields that are copied through to the UseDefMap when we finish building
1123+
// it.
1124+
for bindings in &mut self.bindings_by_use {
1125+
bindings.finish(&mut self.reachability_constraints);
1126+
}
1127+
for constraint in self.node_reachability.values() {
1128+
self.reachability_constraints.mark_used(*constraint);
1129+
}
1130+
for place_state in &mut self.place_states {
1131+
place_state.finish(&mut self.reachability_constraints);
1132+
}
1133+
for reachable_definition in &mut self.reachable_definitions {
1134+
reachable_definition
1135+
.bindings
1136+
.finish(&mut self.reachability_constraints);
1137+
reachable_definition
1138+
.declarations
1139+
.finish(&mut self.reachability_constraints);
1140+
}
1141+
for declarations in self.declarations_by_binding.values_mut() {
1142+
declarations.finish(&mut self.reachability_constraints);
1143+
}
1144+
for bindings in self.bindings_by_definition.values_mut() {
1145+
bindings.finish(&mut self.reachability_constraints);
1146+
}
1147+
for eager_snapshot in &mut self.eager_snapshots {
1148+
eager_snapshot.finish(&mut self.reachability_constraints);
1149+
}
1150+
// The builder's own current reachability is marked directly.
self.reachability_constraints.mark_used(self.reachability);
1151+
}
1152+
11211153
pub(super) fn finish(mut self) -> UseDefMap<'db> {
1154+
self.mark_reachability_constraints();
1155+
11221156
self.all_definitions.shrink_to_fit();
11231157
self.place_states.shrink_to_fit();
11241158
self.reachable_definitions.shrink_to_fit();

0 commit comments

Comments
 (0)