Skip to content

Commit 0e0a60a

Browse files
authored
Define an RAII helper for generic take-and-replace borrow splitting (#10548)
* Define an RAII helper for generic take-and-replace borrow splitting. Follow-up to #10524 (comment). * Use `&mut T`.
1 parent 3da7fc8 commit 0e0a60a

File tree

3 files changed

+120
-55
lines changed

3 files changed

+120
-55
lines changed

cranelift/codegen/src/egraph.rs

Lines changed: 27 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::opts::generated_code::SkeletonInstSimplification;
1616
use crate::opts::IsleContext;
1717
use crate::scoped_hash_map::{Entry as ScopedEntry, ScopedHashMap};
1818
use crate::settings::Flags;
19+
use crate::take_and_replace::TakeAndReplace;
1920
use crate::trace;
2021
use alloc::vec::Vec;
2122
use core::cmp::Ordering;
@@ -298,7 +299,8 @@ where
298299
// A pure node always has exactly one result.
299300
let orig_value = self.func.dfg.first_result(inst);
300301

301-
let mut optimized_values = std::mem::take(&mut self.optimized_values);
302+
let mut guard = TakeAndReplace::new(self, |x| &mut x.optimized_values);
303+
let (ctx, optimized_values) = guard.get();
302304

303305
// Limit rewrite depth. When we apply optimization rules, they
304306
// may create new nodes (values) and those are, recursively,
@@ -310,28 +312,28 @@ where
310312
// infinite or problematic recursion, we bound the rewrite
311313
// depth to a small constant here.
312314
const REWRITE_LIMIT: usize = 5;
313-
if self.rewrite_depth > REWRITE_LIMIT {
314-
self.stats.rewrite_depth_limit += 1;
315+
if ctx.rewrite_depth > REWRITE_LIMIT {
316+
ctx.stats.rewrite_depth_limit += 1;
315317
return orig_value;
316318
}
317-
self.rewrite_depth += 1;
318-
trace!("Incrementing rewrite depth; now {}", self.rewrite_depth);
319+
ctx.rewrite_depth += 1;
320+
trace!("Incrementing rewrite depth; now {}", ctx.rewrite_depth);
319321

320322
// Invoke the ISLE toplevel constructor, getting all new
321323
// values produced as equivalents to this value.
322324
trace!("Calling into ISLE with original value {}", orig_value);
323-
self.stats.rewrite_rule_invoked += 1;
325+
ctx.stats.rewrite_rule_invoked += 1;
324326
debug_assert!(optimized_values.is_empty());
325327
crate::opts::generated_code::constructor_simplify(
326-
&mut IsleContext { ctx: self },
328+
&mut IsleContext { ctx },
327329
orig_value,
328-
&mut optimized_values,
330+
optimized_values,
329331
);
330332

331-
self.stats.rewrite_rule_results += optimized_values.len() as u64;
333+
ctx.stats.rewrite_rule_results += optimized_values.len() as u64;
332334

333335
// It's not supposed to matter what order `simplify` returns values in.
334-
self.ctrl_plane.shuffle(&mut optimized_values);
336+
ctx.ctrl_plane.shuffle(optimized_values);
335337

336338
let num_matches = optimized_values.len();
337339
if num_matches > MATCHES_LIMIT {
@@ -351,10 +353,10 @@ where
351353
// all returned values.
352354
let result_value = if let Some(&subsuming_value) = optimized_values
353355
.iter()
354-
.find(|&value| self.subsume_values.contains(value))
356+
.find(|&value| ctx.subsume_values.contains(value))
355357
{
356358
optimized_values.clear();
357-
self.stats.pure_inst_subsume += 1;
359+
ctx.stats.pure_inst_subsume += 1;
358360
subsuming_value
359361
} else {
360362
let mut union_value = orig_value;
@@ -366,29 +368,27 @@ where
366368
);
367369
if optimized_value == orig_value {
368370
trace!(" -> same as orig value; skipping");
369-
self.stats.pure_inst_rewrite_to_self += 1;
371+
ctx.stats.pure_inst_rewrite_to_self += 1;
370372
continue;
371373
}
372374
let old_union_value = union_value;
373-
union_value = self.func.dfg.union(old_union_value, optimized_value);
374-
self.stats.union += 1;
375+
union_value = ctx.func.dfg.union(old_union_value, optimized_value);
376+
ctx.stats.union += 1;
375377
trace!(" -> union: now {}", union_value);
376-
self.func.dfg.merge_facts(old_union_value, optimized_value);
377-
self.available_block[union_value] =
378-
self.merge_availability(old_union_value, optimized_value);
378+
ctx.func.dfg.merge_facts(old_union_value, optimized_value);
379+
ctx.available_block[union_value] =
380+
ctx.merge_availability(old_union_value, optimized_value);
379381
}
380382
union_value
381383
};
382384

383-
self.rewrite_depth -= 1;
384-
trace!("Decrementing rewrite depth; now {}", self.rewrite_depth);
385-
if self.rewrite_depth == 0 {
386-
self.subsume_values.clear();
385+
ctx.rewrite_depth -= 1;
386+
trace!("Decrementing rewrite depth; now {}", ctx.rewrite_depth);
387+
if ctx.rewrite_depth == 0 {
388+
ctx.subsume_values.clear();
387389
}
388390

389-
debug_assert!(self.optimized_values.is_empty());
390-
self.optimized_values = optimized_values;
391-
391+
debug_assert!(ctx.optimized_values.is_empty());
392392
result_value
393393
}
394394

@@ -564,36 +564,8 @@ where
564564
return None;
565565
}
566566

567-
/// A small RAII helper for temporarily taking out our `optimized_insts`
568-
/// vec and then replacing it upon drop.
569-
struct WithOptimizedInsts<'a, 'opt, 'analysis> {
570-
ctx: &'a mut OptimizeCtx<'opt, 'analysis>,
571-
optimized_insts: SmallVec<[SkeletonInstSimplification; MATCHES_LIMIT]>,
572-
}
573-
574-
impl Drop for WithOptimizedInsts<'_, '_, '_> {
575-
fn drop(&mut self) {
576-
self.optimized_insts.clear();
577-
self.ctx.optimized_insts = std::mem::take(&mut self.optimized_insts);
578-
}
579-
}
580-
581-
impl<'a, 'b, 'c> WithOptimizedInsts<'a, 'b, 'c> {
582-
fn new(ctx: &'a mut OptimizeCtx<'b, 'c>) -> Self {
583-
let optimized_insts = std::mem::take(&mut ctx.optimized_insts);
584-
debug_assert!(optimized_insts.is_empty());
585-
WithOptimizedInsts {
586-
ctx,
587-
optimized_insts,
588-
}
589-
}
590-
}
591-
592-
let mut guard = WithOptimizedInsts::new(self);
593-
let WithOptimizedInsts {
594-
ctx,
595-
optimized_insts,
596-
} = &mut guard;
567+
let mut guard = TakeAndReplace::new(self, |x| &mut x.optimized_insts);
568+
let (ctx, optimized_insts) = guard.get();
597569

598570
crate::opts::generated_code::constructor_simplify_skeleton(
599571
&mut IsleContext { ctx },

cranelift/codegen/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,15 @@ mod ranges;
8282
mod remove_constant_phis;
8383
mod result;
8484
mod scoped_hash_map;
85+
mod take_and_replace;
8586
mod unreachable_code;
8687
mod value_label;
8788

8889
#[cfg(feature = "souper-harvest")]
8990
mod souper_harvest;
9091

9192
pub use crate::result::{CodegenError, CodegenResult, CompileError};
93+
pub use crate::take_and_replace::TakeAndReplace;
9294

9395
#[cfg(feature = "incremental-cache")]
9496
pub mod incremental_cache;
cranelift/codegen/src/take_and_replace.rs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
//! Helper for temporarily taking values out and then putting them back in.

/// An RAII type to temporarily take a `U` out of a `T` and then put it back
/// again on drop.
///
/// This allows you to split borrows, if necessary, to satisfy the borrow
/// checker.
///
/// The `F` type parameter must project from the container type `T` to its `U`
/// that we want to temporarily take out of it.
///
/// # Example
///
/// ```
/// use cranelift_codegen::TakeAndReplace;
///
/// #[derive(Default)]
/// struct BigContextStruct {
///     items: Vec<u32>,
///     count: usize,
/// }
///
/// impl BigContextStruct {
///     fn handle_item(&mut self, item: u32) {
///         self.count += 1;
///         println!("Handled {item}!");
///     }
/// }
///
/// let mut ctx = BigContextStruct::default();
/// ctx.items.extend([42, 1337, 1312]);
///
/// {
///     // Temporarily take `self.items` out of `ctx`.
///     let mut guard = TakeAndReplace::new(&mut ctx, |ctx| &mut ctx.items);
///     let (ctx, items) = guard.get();
///
///     // Now we can both borrow/iterate/mutate `items` and call `&mut self` helper
///     // methods on `ctx`. This would not otherwise be possible if we didn't split
///     // the borrows, since Rust's borrow checker doesn't see through methods and
///     // know that `handle_item` doesn't use `self.items`.
///     for item in items.drain(..) {
///         ctx.handle_item(item);
///     }
/// }
///
/// // When `guard` is dropped, `items` is replaced in `ctx`, allowing us to
/// // reuse its capacity and avoid future allocations.
/// assert!(ctx.items.capacity() >= 3);
/// ```
pub struct TakeAndReplace<'a, T, U, F>
where
    F: Fn(&mut T) -> &mut U,
    U: Default,
{
    // The container that `value` was taken out of, and that it will be
    // returned to on drop.
    container: &'a mut T,
    // The temporarily-taken-out value; `T`'s corresponding field holds
    // `U::default()` while this guard is live.
    value: U,
    // Projection from the container to the field we took `value` from.
    proj: F,
}

impl<'a, T, U, F> Drop for TakeAndReplace<'a, T, U, F>
where
    F: Fn(&mut T) -> &mut U,
    U: Default,
{
    fn drop(&mut self) {
        // Put the taken-out value back into its slot in the container.
        // `mem::take` is used because `drop` only gives us `&mut self`.
        *(self.proj)(self.container) = std::mem::take(&mut self.value);
    }
}

impl<'a, T, U, F> TakeAndReplace<'a, T, U, F>
where
    F: Fn(&mut T) -> &mut U,
    U: Default,
{
    /// Create a new `TakeAndReplace` that temporarily takes out
    /// `proj(container)`, leaving `U::default()` in its place until drop.
    pub fn new(mut container: &'a mut T, proj: F) -> Self {
        let value = std::mem::take(proj(&mut container));
        TakeAndReplace {
            container,
            value,
            proj,
        }
    }

    /// Get the underlying container and taken-out value.
    ///
    /// The returned borrows are disjoint, which is the whole point: callers
    /// can mutate both simultaneously.
    pub fn get(&mut self) -> (&mut T, &mut U) {
        (&mut *self.container, &mut self.value)
    }
}

0 commit comments

Comments
 (0)