Commit 64ea775

interpret: copy_provenance: avoid large intermediate buffer for large repeat counts
1 parent be8de5d

File tree

6 files changed (+81, -81 lines)


compiler/rustc_const_eval/src/interpret/memory.rs
Lines changed: 2 additions & 2 deletions

@@ -1503,7 +1503,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
         // This will also error if copying partial provenance is not supported.
         let provenance = src_alloc
             .provenance()
-            .prepare_copy(src_range, dest_offset, num_copies, self)
+            .prepare_copy(src_range, self)
             .map_err(|e| e.to_interp_error(src_alloc_id))?;
         // Prepare a copy of the initialization mask.
         let init = src_alloc.init_mask().prepare_copy(src_range);
@@ -1589,7 +1589,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
             num_copies,
         );
         // copy the provenance to the destination
-        dest_alloc.provenance_apply_copy(provenance);
+        dest_alloc.provenance_apply_copy(provenance, alloc_range(dest_offset, size), num_copies);

         interp_ok(())
     }
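
The call-site change above reflects the new split: repetition is no longer baked into the prepared copy. A minimal sketch of the difference, with `u64` standing in for `Size` and `char` for the provenance type (illustrative names only, not the rustc API):

// Old strategy: materialize all `count` repetitions at prepare time.
fn prepare_old(src: &[(u64, char)], dest: u64, size: u64, count: u64) -> Vec<(u64, char)> {
    let mut out = Vec::with_capacity(src.len() * count as usize); // O(k * count) memory
    for i in 0..count {
        out.extend(src.iter().map(|&(off, p)| (dest + i * size + off, p)));
    }
    out
}

// New strategy: keep one range-relative copy, O(k) memory; shifting to the
// destination now happens later, once per repetition, at apply time.
fn prepare_new(src: &[(u64, char)], range_start: u64) -> Vec<(u64, char)> {
    src.iter().map(|&(off, p)| (off - range_start, p)).collect()
}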

compiler/rustc_data_structures/src/lib.rs
Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@
 #![feature(sized_hierarchy)]
 #![feature(test)]
 #![feature(thread_id_value)]
+#![feature(trusted_len)]
 #![feature(type_alias_impl_trait)]
 #![feature(unwrap_infallible)]
 // tidy-alphabetical-end
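
The `trusted_len` gate is needed because the reworked `insert_presorted` (next file) bounds its iterator by `TrustedLen`. As a sketch of what that unstable marker guarantees (my gloss, not text from the commit): `size_hint()` may be trusted to be exact, which lets consumers such as `Vec::splice` reserve space once up front.

#![feature(trusted_len)]
use std::iter::TrustedLen;

// For a `TrustedLen` iterator, the upper bound of `size_hint()` is `None`
// only if more than `usize::MAX` items remain; otherwise it is exact.
fn exact_len<I: TrustedLen>(it: &I) -> Option<usize> {
    it.size_hint().1
}

fn main() {
    assert_eq!(exact_len(&vec![1, 2, 3].into_iter()), Some(3)); // vec::IntoIter is TrustedLen
}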

compiler/rustc_data_structures/src/sorted_map.rs
Lines changed: 19 additions & 14 deletions

@@ -1,6 +1,7 @@
 use std::borrow::Borrow;
 use std::cmp::Ordering;
 use std::fmt::Debug;
+use std::iter::TrustedLen;
 use std::mem;
 use std::ops::{Bound, Index, IndexMut, RangeBounds};

@@ -215,32 +216,36 @@ impl<K: Ord, V> SortedMap<K, V> {
     /// It is up to the caller to make sure that the elements are sorted by key
     /// and that there are no duplicates.
     #[inline]
-    pub fn insert_presorted(&mut self, elements: Vec<(K, V)>) {
-        if elements.is_empty() {
+    pub fn insert_presorted(
+        &mut self,
+        // We require `TrustedLen` to ensure that the `splice` below is actually efficient.
+        mut elements: impl Iterator<Item = (K, V)> + DoubleEndedIterator + TrustedLen,
+    ) {
+        let Some(first) = elements.next() else {
             return;
-        }
-
-        debug_assert!(elements.array_windows().all(|[fst, snd]| fst.0 < snd.0));
+        };

-        let start_index = self.lookup_index_for(&elements[0].0);
+        let start_index = self.lookup_index_for(&first.0);

         let elements = match start_index {
             Ok(index) => {
-                let mut elements = elements.into_iter();
-                self.data[index] = elements.next().unwrap();
-                elements
+                self.data[index] = first; // overwrite first element
+                elements.chain(None) // insert the rest below
             }
             Err(index) => {
-                if index == self.data.len() || elements.last().unwrap().0 < self.data[index].0 {
+                let last = elements.next_back();
+                if index == self.data.len()
+                    || last.as_ref().is_none_or(|l| l.0 < self.data[index].0)
+                {
                     // We can copy the whole range without having to mix with
                     // existing elements.
-                    self.data.splice(index..index, elements);
+                    self.data
+                        .splice(index..index, std::iter::once(first).chain(elements).chain(last));
                     return;
                 }

-                let mut elements = elements.into_iter();
-                self.data.insert(index, elements.next().unwrap());
-                elements
+                self.data.insert(index, first);
+                elements.chain(last) // insert the rest below
             }
         };
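
Two things the new bounds buy (my reading, not stated in the commit): `DoubleEndedIterator` lets the fast path peek at the final element with `next_back()` without buffering the whole input, and `TrustedLen` makes `size_hint` exact so the `splice` can size the gap in one step. A usage sketch against the new signature, mirroring the tests below:

use rustc_data_structures::sorted_map::SortedMap;

fn demo() {
    let mut map: SortedMap<u32, u32> = SortedMap::new();
    map.insert(2, 2);
    map.insert(8, 8);
    // `vec::IntoIter` is `DoubleEndedIterator + TrustedLen`, so call sites
    // simply append `.into_iter()` to the `Vec` they used to pass by value.
    map.insert_presorted(vec![(3, 3), (7, 7)].into_iter());
    assert_eq!(map.get(&3), Some(&3));
}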

compiler/rustc_data_structures/src/sorted_map/tests.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ fn test_insert_presorted_non_overlapping() {
171171
map.insert(2, 0);
172172
map.insert(8, 0);
173173

174-
map.insert_presorted(vec![(3, 0), (7, 0)]);
174+
map.insert_presorted(vec![(3, 0), (7, 0)].into_iter());
175175

176176
let expected = vec![2, 3, 7, 8];
177177
assert_eq!(keys(map), expected);
@@ -183,7 +183,7 @@ fn test_insert_presorted_first_elem_equal() {
183183
map.insert(2, 2);
184184
map.insert(8, 8);
185185

186-
map.insert_presorted(vec![(2, 0), (7, 7)]);
186+
map.insert_presorted(vec![(2, 0), (7, 7)].into_iter());
187187

188188
let expected = vec![(2, 0), (7, 7), (8, 8)];
189189
assert_eq!(elements(map), expected);
@@ -195,7 +195,7 @@ fn test_insert_presorted_last_elem_equal() {
195195
map.insert(2, 2);
196196
map.insert(8, 8);
197197

198-
map.insert_presorted(vec![(3, 3), (8, 0)]);
198+
map.insert_presorted(vec![(3, 3), (8, 0)].into_iter());
199199

200200
let expected = vec![(2, 2), (3, 3), (8, 0)];
201201
assert_eq!(elements(map), expected);
@@ -207,7 +207,7 @@ fn test_insert_presorted_shuffle() {
207207
map.insert(2, 2);
208208
map.insert(7, 7);
209209

210-
map.insert_presorted(vec![(1, 1), (3, 3), (8, 8)]);
210+
map.insert_presorted(vec![(1, 1), (3, 3), (8, 8)].into_iter());
211211

212212
let expected = vec![(1, 1), (2, 2), (3, 3), (7, 7), (8, 8)];
213213
assert_eq!(elements(map), expected);
@@ -219,7 +219,7 @@ fn test_insert_presorted_at_end() {
219219
map.insert(1, 1);
220220
map.insert(2, 2);
221221

222-
map.insert_presorted(vec![(3, 3), (8, 8)]);
222+
map.insert_presorted(vec![(3, 3), (8, 8)].into_iter());
223223

224224
let expected = vec![(1, 1), (2, 2), (3, 3), (8, 8)];
225225
assert_eq!(elements(map), expected);
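
One behavioral nuance of the iterator-based signature (an observation, not from the commit): the old `debug_assert!` that validated sortedness over the whole `Vec` is gone, since a single-pass iterator cannot be inspected without consuming it; sortedness is now purely the caller's contract, as the doc comment says. A hypothetical caller-side check, should one want it back:

use rustc_data_structures::sorted_map::SortedMap;

// Not part of the commit: validate sortedness on the slice before
// converting it into an iterator and handing it off.
fn checked_insert(map: &mut SortedMap<u32, u32>, elements: Vec<(u32, u32)>) {
    debug_assert!(elements.is_sorted_by(|a, b| a.0 < b.0));
    map.insert_presorted(elements.into_iter());
}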

compiler/rustc_middle/src/mir/interpret/allocation.rs
Lines changed: 7 additions & 2 deletions

@@ -849,8 +849,13 @@ impl<Prov: Provenance, Extra, Bytes: AllocBytes> Allocation<Prov, Extra, Bytes>
     ///
     /// This is dangerous to use as it can violate internal `Allocation` invariants!
     /// It only exists to support an efficient implementation of `mem_copy_repeatedly`.
-    pub fn provenance_apply_copy(&mut self, copy: ProvenanceCopy<Prov>) {
-        self.provenance.apply_copy(copy)
+    pub fn provenance_apply_copy(
+        &mut self,
+        copy: ProvenanceCopy<Prov>,
+        range: AllocRange,
+        repeat: u64,
+    ) {
+        self.provenance.apply_copy(copy, range, repeat)
     }

     /// Applies a previously prepared copy of the init mask.

compiler/rustc_middle/src/mir/interpret/allocation/provenance_map.rs
Lines changed: 47 additions & 58 deletions

@@ -278,90 +278,78 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {

 /// A partial, owned list of provenance to transfer into another allocation.
 ///
-/// Offsets are already adjusted to the destination allocation.
+/// Offsets are relative to the beginning of the copied range.
 pub struct ProvenanceCopy<Prov> {
-    dest_ptrs: Option<Box<[(Size, Prov)]>>,
-    dest_bytes: Option<Box<[(Size, (Prov, u8))]>>,
+    ptrs: Box<[(Size, Prov)]>,
+    bytes: Box<[(Size, (Prov, u8))]>,
 }

 impl<Prov: Provenance> ProvenanceMap<Prov> {
     pub fn prepare_copy(
         &self,
-        src: AllocRange,
-        dest: Size,
-        count: u64,
+        range: AllocRange,
         cx: &impl HasDataLayout,
     ) -> AllocResult<ProvenanceCopy<Prov>> {
-        let shift_offset = move |idx, offset| {
-            // compute offset for current repetition
-            let dest_offset = dest + src.size * idx; // `Size` operations
-            // shift offsets from source allocation to destination allocation
-            (offset - src.start) + dest_offset // `Size` operations
-        };
+        let shift_offset = move |offset| offset - range.start;
         let ptr_size = cx.data_layout().pointer_size();

         // # Pointer-sized provenances
         // Get the provenances that are entirely within this range.
         // (Different from `range_get_ptrs` which asks if they overlap the range.)
         // Only makes sense if we are copying at least one pointer worth of bytes.
-        let mut dest_ptrs_box = None;
-        if src.size >= ptr_size {
-            let adjusted_end = Size::from_bytes(src.end().bytes() - (ptr_size.bytes() - 1));
-            let ptrs = self.ptrs.range(src.start..adjusted_end);
-            // If `count` is large, this is rather wasteful -- we are allocating a big array here, which
-            // is mostly filled with redundant information since it's just N copies of the same `Prov`s
-            // at slightly adjusted offsets. The reason we do this is so that in `mark_provenance_range`
-            // we can use `insert_presorted`. That wouldn't work with an `Iterator` that just produces
-            // the right sequence of provenance for all N copies.
-            // Basically, this large array would have to be created anyway in the target allocation.
-            let mut dest_ptrs = Vec::with_capacity(ptrs.len() * (count as usize));
-            for i in 0..count {
-                dest_ptrs
-                    .extend(ptrs.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)));
-            }
-            debug_assert_eq!(dest_ptrs.len(), dest_ptrs.capacity());
-            dest_ptrs_box = Some(dest_ptrs.into_boxed_slice());
+        let mut ptrs_box: Box<[_]> = Box::new([]);
+        if range.size >= ptr_size {
+            let adjusted_end = Size::from_bytes(range.end().bytes() - (ptr_size.bytes() - 1));
+            let ptrs = self.ptrs.range(range.start..adjusted_end);
+            ptrs_box = ptrs.iter().map(|&(offset, reloc)| (shift_offset(offset), reloc)).collect();
         };

         // # Byte-sized provenances
         // This includes the existing bytewise provenance in the range, and ptr provenance
         // that overlaps with the begin/end of the range.
-        let mut dest_bytes_box = None;
-        let begin_overlap = self.range_ptrs_get(alloc_range(src.start, Size::ZERO), cx).first();
-        let end_overlap = self.range_ptrs_get(alloc_range(src.end(), Size::ZERO), cx).first();
+        let mut bytes_box: Box<[_]> = Box::new([]);
+        let begin_overlap = self.range_ptrs_get(alloc_range(range.start, Size::ZERO), cx).first();
+        let end_overlap = self.range_ptrs_get(alloc_range(range.end(), Size::ZERO), cx).first();
         // We only need to go here if there is some overlap or some bytewise provenance.
         if begin_overlap.is_some() || end_overlap.is_some() || self.bytes.is_some() {
             let mut bytes: Vec<(Size, (Prov, u8))> = Vec::new();
             // First, if there is a part of a pointer at the start, add that.
             if let Some(entry) = begin_overlap {
                 trace!("start overlapping entry: {entry:?}");
-                // For really small copies, make sure we don't run off the end of the `src` range.
-                let entry_end = cmp::min(entry.0 + ptr_size, src.end());
-                for offset in src.start..entry_end {
-                    bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
+                // For really small copies, make sure we don't run off the end of the range.
+                let entry_end = cmp::min(entry.0 + ptr_size, range.end());
+                for offset in range.start..entry_end {
+                    bytes.push((shift_offset(offset), (entry.1, (offset - entry.0).bytes() as u8)));
                 }
             } else {
                 trace!("no start overlapping entry");
             }

             // Then the main part, bytewise provenance from `self.bytes`.
-            bytes.extend(self.range_bytes_get(src));
+            bytes.extend(
+                self.range_bytes_get(range)
+                    .iter()
+                    .map(|&(offset, reloc)| (shift_offset(offset), reloc)),
+            );

             // And finally possibly parts of a pointer at the end.
             if let Some(entry) = end_overlap {
                 trace!("end overlapping entry: {entry:?}");
-                // For really small copies, make sure we don't start before `src` does.
-                let entry_start = cmp::max(entry.0, src.start);
-                for offset in entry_start..src.end() {
+                // For really small copies, make sure we don't start before `range` does.
+                let entry_start = cmp::max(entry.0, range.start);
+                for offset in entry_start..range.end() {
                     if bytes.last().is_none_or(|bytes_entry| bytes_entry.0 < offset) {
                         // The last entry, if it exists, has a lower offset than us, so we
                         // can add it at the end and remain sorted.
-                        bytes.push((offset, (entry.1, (offset - entry.0).bytes() as u8)));
+                        bytes.push((
+                            shift_offset(offset),
+                            (entry.1, (offset - entry.0).bytes() as u8),
+                        ));
                     } else {
                         // There already is an entry for this offset in there! This can happen when the
                         // start and end range checks actually end up hitting the same pointer, so we
                         // already added this in the "pointer at the start" part above.
-                        assert!(entry.0 <= src.start);
+                        assert!(entry.0 <= range.start);
                     }
                 }
             } else {
@@ -372,33 +360,34 @@ impl<Prov: Provenance> ProvenanceMap<Prov> {
             if !bytes.is_empty() && !Prov::OFFSET_IS_ADDR {
                 // FIXME(#146291): We need to ensure that we don't mix different pointers with
                 // the same provenance.
-                return Err(AllocError::ReadPartialPointer(src.start));
+                return Err(AllocError::ReadPartialPointer(range.start));
             }

             // And again a buffer for the new list on the target side.
-            let mut dest_bytes = Vec::with_capacity(bytes.len() * (count as usize));
-            for i in 0..count {
-                dest_bytes
-                    .extend(bytes.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)));
-            }
-            debug_assert_eq!(dest_bytes.len(), dest_bytes.capacity());
-            dest_bytes_box = Some(dest_bytes.into_boxed_slice());
+            bytes_box = bytes.into_boxed_slice();
         }

-        Ok(ProvenanceCopy { dest_ptrs: dest_ptrs_box, dest_bytes: dest_bytes_box })
+        Ok(ProvenanceCopy { ptrs: ptrs_box, bytes: bytes_box })
     }

     /// Applies a provenance copy.
     /// The affected range, as defined in the parameters to `prepare_copy` is expected
     /// to be clear of provenance.
-    pub fn apply_copy(&mut self, copy: ProvenanceCopy<Prov>) {
-        if let Some(dest_ptrs) = copy.dest_ptrs {
-            self.ptrs.insert_presorted(dest_ptrs.into());
+    pub fn apply_copy(&mut self, copy: ProvenanceCopy<Prov>, range: AllocRange, repeat: u64) {
+        let shift_offset = |idx: u64, offset: Size| offset + range.start + idx * range.size;
+        if !copy.ptrs.is_empty() {
+            for i in 0..repeat {
+                self.ptrs.insert_presorted(
+                    copy.ptrs.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)),
+                );
+            }
         }
-        if let Some(dest_bytes) = copy.dest_bytes
-            && !dest_bytes.is_empty()
-        {
-            self.bytes.get_or_insert_with(Box::default).insert_presorted(dest_bytes.into());
+        if !copy.bytes.is_empty() {
+            for i in 0..repeat {
+                self.bytes.get_or_insert_with(Box::default).insert_presorted(
+                    copy.bytes.iter().map(|&(offset, reloc)| (shift_offset(i, offset), reloc)),
+                );
+            }
         }
     }
 }
