Skip to content

Commit 271038b

Browse files
[perf] Use Vec instead of BTreeMap in tree_from_memory (#1338)
* Build memory tree presumably faster
* Construct all-zero nodes in advance
* Revert "Construct all-zero nodes in advance" (this reverts commit 9f7dc7f)
* Cosmetic change
* Calculate the capacity first, then reserve exactly this much
* Revert "Calculate the capacity first, then reserve exactly this much" (this reverts commit 3ee49b3)
* Make pagedvec iter return &T, not T
* Calculate the capacity first, then reserve exactly this much
* Revert "Make pagedvec iter return &T, not T" (this reverts commit f1e8403)
* Revert "Calculate the capacity first, then reserve exactly this much" (this reverts commit c719be1)
* Cache only all-zero leaf
* Avoid building `Vec<(_, [F; CHUNK])>`
* Revert "Cache only all-zero leaf" (this reverts commit 98ae517)
* Revert "Revert "Cache only all-zero leaf"" (this reverts commit e247f5b)

---------

Co-authored-by: Jonathan Wang <[email protected]>
1 parent 9d1c716 commit 271038b

File tree

1 file changed

+83
-24
lines changed
  • crates/vm/src/system/memory/tree

1 file changed

+83
-24
lines changed

crates/vm/src/system/memory/tree/mod.rs

Lines changed: 83 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
pub mod public_values;
22

3-
use std::{collections::BTreeMap, sync::Arc};
3+
use std::{ops::Range, sync::Arc};
44

55
use openvm_stark_backend::{p3_field::PrimeField32, p3_maybe_rayon::prelude::*};
66
use MemoryNode::*;
@@ -66,23 +66,71 @@ impl<const CHUNK: usize, F: PrimeField32> MemoryNode<CHUNK, F> {
6666
}
6767

6868
fn from_memory(
69-
memory: &BTreeMap<u64, [F; CHUNK]>,
70-
height: usize,
69+
memory: &[(u64, F)],
70+
lookup_range: Range<usize>,
71+
length: u64,
7172
from: u64,
7273
hasher: &(impl Hasher<CHUNK, F> + Sync),
74+
zero_leaf: &MemoryNode<CHUNK, F>,
7375
) -> MemoryNode<CHUNK, F> {
74-
let mut range = memory.range(from..from + (1 << height));
75-
if height == 0 {
76-
let values = *memory.get(&from).unwrap_or(&[F::ZERO; CHUNK]);
77-
MemoryNode::new_leaf(hasher.hash(&values))
78-
} else if range.next().is_none() {
76+
if length == CHUNK as u64 {
77+
if lookup_range.is_empty() {
78+
zero_leaf.clone()
79+
} else {
80+
debug_assert_eq!(memory[lookup_range.start].0, from);
81+
let mut values = [F::ZERO; CHUNK];
82+
for (index, value) in memory[lookup_range].iter() {
83+
values[(index % CHUNK as u64) as usize] = *value;
84+
}
85+
MemoryNode::new_leaf(hasher.hash(&values))
86+
}
87+
} else if lookup_range.is_empty() {
7988
let leaf_value = hasher.hash(&[F::ZERO; CHUNK]);
80-
MemoryNode::construct_uniform(height, leaf_value, hasher)
89+
MemoryNode::construct_uniform(
90+
(length / CHUNK as u64).trailing_zeros() as usize,
91+
leaf_value,
92+
hasher,
93+
)
8194
} else {
82-
let midpoint = from + (1 << (height - 1));
95+
let midpoint = from + length / 2;
96+
let mid = {
97+
let mut left = lookup_range.start;
98+
let mut right = lookup_range.end;
99+
if memory[left].0 >= midpoint {
100+
left
101+
} else {
102+
while left + 1 < right {
103+
let mid = left + (right - left) / 2;
104+
if memory[mid].0 < midpoint {
105+
left = mid;
106+
} else {
107+
right = mid;
108+
}
109+
}
110+
right
111+
}
112+
};
83113
let (left, right) = join(
84-
|| Self::from_memory(memory, height - 1, from, hasher),
85-
|| Self::from_memory(memory, height - 1, midpoint, hasher),
114+
|| {
115+
Self::from_memory(
116+
memory,
117+
lookup_range.start..mid,
118+
length >> 1,
119+
from,
120+
hasher,
121+
zero_leaf,
122+
)
123+
},
124+
|| {
125+
Self::from_memory(
126+
memory,
127+
mid..lookup_range.end,
128+
length >> 1,
129+
midpoint,
130+
hasher,
131+
zero_leaf,
132+
)
133+
},
86134
);
87135
NonLeaf {
88136
hash: hasher.compress(&left.hash(), &right.hash()),
@@ -97,22 +145,33 @@ impl<const CHUNK: usize, F: PrimeField32> MemoryNode<CHUNK, F> {
97145
memory: &MemoryImage<F>,
98146
hasher: &(impl Hasher<CHUNK, F> + Sync),
99147
) -> MemoryNode<CHUNK, F> {
100-
// Construct a BTreeMap that includes the address space in the label calculation,
148+
// Construct a Vec that includes the address space in the label calculation,
101149
// representing the entire memory tree.
102-
let mut memory_partition = BTreeMap::new();
103-
for ((address_space, pointer), value) in memory.items() {
104-
let label = (address_space, pointer / CHUNK as u32);
105-
let index = memory_dimensions.label_to_index(label);
106-
let chunk = memory_partition
107-
.entry(index)
108-
.or_insert_with(|| [F::ZERO; CHUNK]);
109-
chunk[(pointer % CHUNK as u32) as usize] = value;
110-
}
150+
let memory_items = memory
151+
.items()
152+
.filter(|((_, ptr), _)| *ptr as usize / CHUNK < (1 << memory_dimensions.address_height))
153+
.map(|((address_space, pointer), value)| {
154+
(
155+
memory_dimensions.label_to_index((address_space, pointer / CHUNK as u32))
156+
* CHUNK as u64
157+
+ (pointer % CHUNK as u32) as u64,
158+
value,
159+
)
160+
})
161+
.collect::<Vec<_>>();
162+
debug_assert!(memory_items.is_sorted_by_key(|(addr, _)| addr));
163+
debug_assert!(
164+
memory_items.last().map_or(0, |(addr, _)| *addr)
165+
< ((CHUNK as u64) << memory_dimensions.overall_height())
166+
);
167+
let zero_leaf = MemoryNode::new_leaf(hasher.hash(&[F::ZERO; CHUNK]));
111168
Self::from_memory(
112-
&memory_partition,
113-
memory_dimensions.overall_height(),
169+
&memory_items,
170+
0..memory_items.len(),
171+
(CHUNK as u64) << memory_dimensions.overall_height(),
114172
0,
115173
hasher,
174+
&zero_leaf,
116175
)
117176
}
118177
}

0 commit comments

Comments
 (0)