Commit 5d1b897

Auto merge of rust-lang#146331 - RalfJung:copy-prov-repeat, r=oli-obk

interpret: copy_provenance: avoid large intermediate buffer for large repeat counts

Copying provenance worked in an odd way: the "preparation" phase, which is supposed to just extract the necessary information from the source range, already did all the work of repeating the result N times for the target range. This was needed to use the existing `insert_presorted` function on `SortedMap`. This PR generalizes `insert_presorted` so that we can avoid this odd structure in copy-provenance, and maybe even improve performance.

2 parents ce6daf3 + 7abbc9c, commit 5d1b897
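
The core idea, as a minimal self-contained sketch (not the rustc code itself; `prepare` and `apply_repeated` are hypothetical stand-ins for `prepare_copy`/`apply_copy`): preparation extracts a single chunk of entries with offsets relative to the copied range, and only the apply step expands that chunk lazily into N shifted copies, so no N-sized intermediate buffer is ever allocated.

    // Stand-alone sketch, not rustc code: `prepare` extracts ONE chunk of
    // (relative offset, value) pairs; `apply_repeated` lazily expands it into
    // N shifted copies without materializing an N-sized buffer.
    fn prepare(src: &[(u64, char)], start: u64, size: u64) -> Vec<(u64, char)> {
        src.iter()
            .filter(|&&(off, _)| off >= start && off < start + size)
            .map(|&(off, v)| (off - start, v)) // offsets relative to the copied range
            .collect()
    }

    fn apply_repeated(
        chunk: &[(u64, char)],
        dest_start: u64,
        size: u64,
        repeat: u64,
    ) -> impl Iterator<Item = (u64, char)> + '_ {
        let chunk_len = chunk.len() as u64;
        (0..chunk_len * repeat).map(move |i| {
            let copy_idx = i / chunk_len; // which repetition this entry belongs to
            let (off, v) = chunk[(i % chunk_len) as usize];
            (dest_start + copy_idx * size + off, v) // shift into the destination
        })
    }

    fn main() {
        let src = [(0, 'a'), (3, 'b'), (9, 'c')];
        let chunk = prepare(&src, 0, 4); // [(0, 'a'), (3, 'b')]
        // Three copies of the 4-byte range, starting at destination offset 16.
        let out: Vec<_> = apply_repeated(&chunk, 16, 4, 3).collect();
        assert_eq!(out, [(16u64, 'a'), (19, 'b'), (20, 'a'), (23, 'b'), (24, 'a'), (27, 'b')]);
    }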

File tree: 6 files changed (+88 lines added, -82 lines removed)


compiler/rustc_const_eval/src/interpret/memory.rs

Lines changed: 2 additions & 2 deletions
@@ -1504,7 +1504,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         // This will also error if copying partial provenance is not supported.
         let provenance = src_alloc
             .provenance()
-            .prepare_copy(src_range, dest_offset, num_copies, self)
+            .prepare_copy(src_range, self)
             .map_err(|e| e.to_interp_error(src_alloc_id))?;
         // Prepare a copy of the initialization mask.
         let init = src_alloc.init_mask().prepare_copy(src_range);
@@ -1590,7 +1590,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
             num_copies,
         );
         // copy the provenance to the destination
-        dest_alloc.provenance_apply_copy(provenance);
+        dest_alloc.provenance_apply_copy(provenance, alloc_range(dest_offset, size), num_copies);

        interp_ok(())
    }

compiler/rustc_data_structures/src/lib.rs

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
 #![feature(sized_hierarchy)]
 #![feature(test)]
 #![feature(thread_id_value)]
+#![feature(trusted_len)]
 #![feature(type_alias_impl_trait)]
 #![feature(unwrap_infallible)]
 // tidy-alphabetical-end

compiler/rustc_data_structures/src/sorted_map.rs

Lines changed: 20 additions & 15 deletions
@@ -1,6 +1,7 @@
 use std::borrow::Borrow;
 use std::cmp::Ordering;
 use std::fmt::Debug;
+use std::iter::TrustedLen;
 use std::mem;
 use std::ops::{Bound, Index, IndexMut, RangeBounds};

@@ -215,36 +216,40 @@ impl<K: Ord, V> SortedMap<K, V> {
     /// It is up to the caller to make sure that the elements are sorted by key
     /// and that there are no duplicates.
     #[inline]
-    pub fn insert_presorted(&mut self, elements: Vec<(K, V)>) {
-        if elements.is_empty() {
+    pub fn insert_presorted(
+        &mut self,
+        // We require `TrustedLen` to ensure that the `splice` below is actually efficient.
+        mut elements: impl Iterator<Item = (K, V)> + DoubleEndedIterator + TrustedLen,
+    ) {
+        let Some(first) = elements.next() else {
             return;
-        }
-
-        debug_assert!(elements.array_windows().all(|[fst, snd]| fst.0 < snd.0));
+        };

-        let start_index = self.lookup_index_for(&elements[0].0);
+        let start_index = self.lookup_index_for(&first.0);

         let elements = match start_index {
             Ok(index) => {
-                let mut elements = elements.into_iter();
-                self.data[index] = elements.next().unwrap();
-                elements
+                self.data[index] = first; // overwrite first element
+                elements.chain(None) // insert the rest below
             }
             Err(index) => {
-                if index == self.data.len() || elements.last().unwrap().0 < self.data[index].0 {
+                let last = elements.next_back();
+                if index == self.data.len()
+                    || last.as_ref().is_none_or(|l| l.0 < self.data[index].0)
+                {
                     // We can copy the whole range without having to mix with
                     // existing elements.
-                    self.data.splice(index..index, elements);
+                    self.data
+                        .splice(index..index, std::iter::once(first).chain(elements).chain(last));
                     return;
                 }

-                let mut elements = elements.into_iter();
-                self.data.insert(index, elements.next().unwrap());
-                elements
+                self.data.insert(index, first);
+                elements.chain(last) // insert the rest below
             }
         };

-        // Insert the rest
+        // Insert the rest. This is super inefficient since each insertion copies the entire tail.
         for (k, v) in elements {
             self.insert(k, v);
         }

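The `TrustedLen` bound on the new signature exists so that the single `splice` call can reserve space once and shift the existing tail only once. A rough stand-alone illustration of that difference (plain `Vec` code, not the rustc implementation):

    // A single `splice` with an iterator whose length is known up front shifts the
    // tail `[9, 10]` exactly once; inserting element by element would shift the
    // tail once per inserted element instead.
    fn main() {
        let mut data = vec![1, 4, 9, 10];
        let new = [5, 6, 7]; // presorted, all fit between 4 and 9

        let idx = data.partition_point(|&x| x < new[0]); // insertion point: 2
        data.splice(idx..idx, new.iter().copied());

        assert_eq!(data, vec![1, 4, 5, 6, 7, 9, 10]);
    }
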
compiler/rustc_data_structures/src/sorted_map/tests.rs

Lines changed: 5 additions & 5 deletions
@@ -171,7 +171,7 @@ fn test_insert_presorted_non_overlapping() {
     map.insert(2, 0);
     map.insert(8, 0);

-    map.insert_presorted(vec![(3, 0), (7, 0)]);
+    map.insert_presorted(vec![(3, 0), (7, 0)].into_iter());

     let expected = vec![2, 3, 7, 8];
     assert_eq!(keys(map), expected);
@@ -183,7 +183,7 @@ fn test_insert_presorted_first_elem_equal() {
     map.insert(2, 2);
     map.insert(8, 8);

-    map.insert_presorted(vec![(2, 0), (7, 7)]);
+    map.insert_presorted(vec![(2, 0), (7, 7)].into_iter());

     let expected = vec![(2, 0), (7, 7), (8, 8)];
     assert_eq!(elements(map), expected);
@@ -195,7 +195,7 @@ fn test_insert_presorted_last_elem_equal() {
     map.insert(2, 2);
     map.insert(8, 8);

-    map.insert_presorted(vec![(3, 3), (8, 0)]);
+    map.insert_presorted(vec![(3, 3), (8, 0)].into_iter());

     let expected = vec![(2, 2), (3, 3), (8, 0)];
     assert_eq!(elements(map), expected);
@@ -207,7 +207,7 @@ fn test_insert_presorted_shuffle() {
     map.insert(2, 2);
     map.insert(7, 7);

-    map.insert_presorted(vec![(1, 1), (3, 3), (8, 8)]);
+    map.insert_presorted(vec![(1, 1), (3, 3), (8, 8)].into_iter());

     let expected = vec![(1, 1), (2, 2), (3, 3), (7, 7), (8, 8)];
     assert_eq!(elements(map), expected);
@@ -219,7 +219,7 @@ fn test_insert_presorted_at_end() {
     map.insert(1, 1);
     map.insert(2, 2);

-    map.insert_presorted(vec![(3, 3), (8, 8)]);
+    map.insert_presorted(vec![(3, 3), (8, 8)].into_iter());

     let expected = vec![(1, 1), (2, 2), (3, 3), (8, 8)];
     assert_eq!(elements(map), expected);

compiler/rustc_middle/src/mir/interpret/allocation.rs

Lines changed: 7 additions & 2 deletions
@@ -849,8 +849,13 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
     ///
     /// This is dangerous to use as it can violate internal `Allocation` invariants!
     /// It only exists to support an efficient implementation of `mem_copy_repeatedly`.
-    pub fn provenance_apply_copy(&mut self, copy: ProvenanceCopy<Prov>) {
-        self.provenance.apply_copy(copy)
+    pub fn provenance_apply_copy(
+        &mut self,
+        copy: ProvenanceCopy<Prov>,
+        range: AllocRange,
+        repeat: u64,
+    ) {
+        self.provenance.apply_copy(copy, range, repeat)
     }

     /// Applies a previously prepared copy of the init mask.

compiler/rustc_middle/src/mir/interpret/allocation/provenance_map.rs

Lines changed: 53 additions & 58 deletions
@@ -278,90 +278,78 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {

 /// A partial, owned list of provenance to transfer into another allocation.
 ///
-/// Offsets are already adjusted to the destination allocation.
+/// Offsets are relative to the beginning of the copied range.
 pub struct ProvenanceCopy<Prov> {
-    dest_ptrs: Option<Box<[(Size, Prov)]>>,
-    dest_bytes: Option<Box<[(Size, (Prov, u8))]>>,
+    ptrs: Box<[(Size, Prov)]>,
+    bytes: Box<[(Size, (Prov, u8))]>,
 }

 impl<Prov: Provenance> ProvenanceMap<Prov> {
     pub fn prepare_copy(
         &self,
-        src: AllocRange,
-        dest: Size,
-        count: u64,
+        range: AllocRange,
         cx: &impl HasDataLayout,
     ) -> AllocResult<ProvenanceCopy<Prov>> {
-        let shift_offset = move |idx, offset| {
-            // compute offset for current repetition
-            let dest_offset = dest + src.size * idx; // `Size` operations
-            // shift offsets from source allocation to destination allocation
-            (offset - src.start) + dest_offset // `Size` operations
-        };
+        let shift_offset = move |offset| offset - range.start;
         let ptr_size = cx.data_layout().pointer_size();

         // # Pointer-sized provenances
         // Get the provenances that are entirely within this range.
         // (Different from `range_get_ptrs` which asks if they overlap the range.)
         // Only makes sense if we are copying at least one pointer worth of bytes.
-        let mut dest_ptrs_box = None;
-        if src.size >= ptr_size {
-            let adjusted_end = Size::from_bytes(src.end().bytes() - (ptr_size.bytes() - 1));
-            let ptrs = self.ptrs.range(src.start..adjusted_end);
-            // If `count` is large, this is rather wasteful -- we are allocating a big array here, which
-            // is mostly filled with redundant information since it's just N copies of the same `Prov`s
-            // at slightly adjusted offsets. The reason we do this is so that in `mark_provenance_range`
-            // we can use `insert_presorted`. That wouldn't work with an `Iterator` that just produces
-            // the right sequence of provenance for all N copies.
-            // Basically, this large array would have to be created anyway in the target allocation.
-            let mut dest_ptrs = Vec::with_capacity(ptrs.len() * (count as usize));
-            for i in 0..count {
-                dest_ptrs
-                    .extend(ptrs.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)));
-            }
-            debug_assert_eq!(dest_ptrs.len(), dest_ptrs.capacity());
-            dest_ptrs_box = Some(dest_ptrs.into_boxed_slice());
+        let mut ptrs_box: Box<[_]> = Box::new([]);
+        if range.size >= ptr_size {
+            let adjusted_end = Size::from_bytes(range.end().bytes() - (ptr_size.bytes() - 1));
+            let ptrs = self.ptrs.range(range.start..adjusted_end);
+            ptrs_box = ptrs.iter().map(|&(offset, reloc)| (shift_offset(offset), reloc)).collect();
        };

         // # Byte-sized provenances
         // This includes the existing bytewise provenance in the range, and ptr provenance
         // that overlaps with the begin/end of the range.
-        let mut dest_bytes_box = None;
-        let begin_overlap = self.range_ptrs_get(alloc_range(src.start, Size::ZERO), cx).first();
-        let end_overlap = self.range_ptrs_get(alloc_range(src.end(), Size::ZERO), cx).first();
+        let mut bytes_box: Box<[_]> = Box::new([]);
+        let begin_overlap = self.range_ptrs_get(alloc_range(range.start, Size::ZERO), cx).first();
+        let end_overlap = self.range_ptrs_get(alloc_range(range.end(), Size::ZERO), cx).first();
         // We only need to go here if there is some overlap or some bytewise provenance.
         if begin_overlap.is_some() || end_overlap.is_some() || self.bytes.is_some() {
             let mut bytes: Vec<(Size, (Prov, u8))> = Vec::new();
             // First, if there is a part of a pointer at the start, add that.
             if let Some(entry) = begin_overlap {
                 trace!("start overlapping entry: {entry:?}");
-                // For really small copies, make sure we don't run off the end of the `src` range.
-                let entry_end = cmp::min(entry.0 + ptr_size, src.end());
-                for offset in src.start..entry_end {
-                    bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
+                // For really small copies, make sure we don't run off the end of the range.
+                let entry_end = cmp::min(entry.0 + ptr_size, range.end());
+                for offset in range.start..entry_end {
+                    bytes.push((shift_offset(offset), (entry.1, (offset - entry.0).bytes() as u8)));
                 }
             } else {
                 trace!("no start overlapping entry");
             }

             // Then the main part, bytewise provenance from `self.bytes`.
-            bytes.extend(self.range_bytes_get(src));
+            bytes.extend(
+                self.range_bytes_get(range)
+                    .iter()
+                    .map(|&(offset, reloc)| (shift_offset(offset), reloc)),
+            );

             // And finally possibly parts of a pointer at the end.
             if let Some(entry) = end_overlap {
                 trace!("end overlapping entry: {entry:?}");
-                // For really small copies, make sure we don't start before `src` does.
-                let entry_start = cmp::max(entry.0, src.start);
-                for offset in entry_start..src.end() {
+                // For really small copies, make sure we don't start before `range` does.
+                let entry_start = cmp::max(entry.0, range.start);
+                for offset in entry_start..range.end() {
                     if bytes.last().is_none_or(|bytes_entry| bytes_entry.0 < offset) {
                         // The last entry, if it exists, has a lower offset than us, so we
                         // can add it at the end and remain sorted.
-                        bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
+                        bytes.push((
+                            shift_offset(offset),
+                            (entry.1, (offset - entry.0).bytes() as u8),
+                        ));
                     } else {
                         // There already is an entry for this offset in there! This can happen when the
                         // start and end range checks actually end up hitting the same pointer, so we
                         // already added this in the "pointer at the start" part above.
-                        assert!(entry.0 <= src.start);
+                        assert!(entry.0 <= range.start);
                     }
                 }
             } else {
@@ -372,33 +360,40 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
             if !bytes.is_empty() && !Prov::OFFSET_IS_ADDR {
                 // FIXME(#146291): We need to ensure that we don't mix different pointers with
                 // the same provenance.
-                return Err(AllocError::ReadPartialPointer(src.start));
+                return Err(AllocError::ReadPartialPointer(range.start));
             }

             // And again a buffer for the new list on the target side.
-            let mut dest_bytes = Vec::with_capacity(bytes.len() * (count as usize));
-            for i in 0..count {
-                dest_bytes
-                    .extend(bytes.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)));
-            }
-            debug_assert_eq!(dest_bytes.len(), dest_bytes.capacity());
-            dest_bytes_box = Some(dest_bytes.into_boxed_slice());
+            bytes_box = bytes.into_boxed_slice();
         }

-        Ok(ProvenanceCopy { dest_ptrs: dest_ptrs_box, dest_bytes: dest_bytes_box })
+        Ok(ProvenanceCopy { ptrs: ptrs_box, bytes: bytes_box })
     }

     /// Applies a provenance copy.
     /// The affected range, as defined in the parameters to `prepare_copy` is expected
     /// to be clear of provenance.
-    pub fn apply_copy(&mut self, copy: ProvenanceCopy<Prov>) {
-        if let Some(dest_ptrs) = copy.dest_ptrs {
-            self.ptrs.insert_presorted(dest_ptrs.into());
+    pub fn apply_copy(&mut self, copy: ProvenanceCopy<Prov>, range: AllocRange, repeat: u64) {
+        let shift_offset = |idx: u64, offset: Size| offset + range.start + idx * range.size;
+        if !copy.ptrs.is_empty() {
+            // We want to call `insert_presorted` only once so that, if possible, the entries
+            // after the range we insert are moved back only once.
+            let chunk_len = copy.ptrs.len() as u64;
+            self.ptrs.insert_presorted((0..chunk_len * repeat).map(|i| {
+                let chunk = i / chunk_len;
+                let (offset, reloc) = copy.ptrs[(i % chunk_len) as usize];
+                (shift_offset(chunk, offset), reloc)
+            }));
         }
-        if let Some(dest_bytes) = copy.dest_bytes
-            && !dest_bytes.is_empty()
-        {
-            self.bytes.get_or_insert_with(Box::default).insert_presorted(dest_bytes.into());
+        if !copy.bytes.is_empty() {
+            let chunk_len = copy.bytes.len() as u64;
+            self.bytes.get_or_insert_with(Box::default).insert_presorted(
+                (0..chunk_len * repeat).map(|i| {
+                    let chunk = i / chunk_len;
+                    let (offset, reloc) = copy.bytes[(i % chunk_len) as usize];
+                    (shift_offset(chunk, offset), reloc)
+                }),
+            );
         }
     }
 }

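The apply step above flattens `repeat` copies of the prepared chunk into one iterator so that `insert_presorted` is called only once. A quick stand-alone check (not rustc code; `range_start` and `range_size` stand in for `range.start` and `range.size`) that this flattened sequence stays strictly increasing, which is the precondition `insert_presorted` relies on, as long as the chunk's relative offsets are sorted and smaller than the range size:

    fn main() {
        let chunk: &[u64] = &[0, 3, 5]; // sorted relative offsets, all < range_size
        let (range_start, range_size, repeat) = (100u64, 8u64, 3u64);

        let chunk_len = chunk.len() as u64;
        let expanded: Vec<u64> = (0..chunk_len * repeat)
            .map(|i| {
                let copy_idx = i / chunk_len;
                let offset = chunk[(i % chunk_len) as usize];
                range_start + copy_idx * range_size + offset
            })
            .collect();

        // 100, 103, 105, 108, 111, 113, 116, 119, 121
        assert!(expanded.windows(2).all(|w| w[0] < w[1]));
        println!("{expanded:?}");
    }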