Skip to content

Commit 55f5986

Browse files
eddybFirestar99
authored and committed
WIP: (TODO: finish bottom-up cleanups) bottom-up inlining
1 parent a3b0068 commit 55f5986

File tree

2 files changed

+107
-82
lines changed

2 files changed

+107
-82
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 96 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ use rustc_session::Session;
1717
use smallvec::SmallVec;
1818
use std::mem;
1919

20-
type FunctionMap = FxHashMap<Word, Function>;
21-
2220
// FIXME(eddyb) this is a bit silly, but this keeps being repeated everywhere.
2321
fn next_id(header: &mut ModuleHeader) -> Word {
2422
let result = header.bound;
@@ -30,6 +28,9 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3028
// This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
3129
deny_recursion_in_module(sess, module)?;
3230

31+
// Compute the call-graph that will drive (inside-out, aka bottom-up) inlining.
32+
let (call_graph, func_id_to_idx) = CallGraph::collect_with_func_id_to_idx(module);
33+
3334
let custom_ext_inst_set_import = module
3435
.ext_inst_imports
3536
.iter()
@@ -39,62 +40,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3940
})
4041
.map(|inst| inst.result_id.unwrap());
4142

42-
// HACK(eddyb) compute the set of functions that may `Abort` *transitively*,
43-
// which is only needed because of how we inline (sometimes it's outside-in,
44-
// aka top-down, instead of always being inside-out, aka bottom-up).
45-
//
46-
// (inlining is needed in the first place because our custom `Abort`
47-
// instructions get lowered to a simple `OpReturn` in entry-points, but
48-
// that requires that they get inlined all the way up to the entry-points)
49-
let functions_that_may_abort = custom_ext_inst_set_import
50-
.map(|custom_ext_inst_set_import| {
51-
let mut may_abort_by_id = FxHashSet::default();
52-
53-
// FIXME(eddyb) use this `CallGraph` abstraction more during inlining.
54-
let call_graph = CallGraph::collect(module);
55-
for func_idx in call_graph.post_order() {
56-
let func_id = module.functions[func_idx].def_id().unwrap();
57-
58-
let any_callee_may_abort = call_graph.callees[func_idx].iter().any(|&callee_idx| {
59-
may_abort_by_id.contains(&module.functions[callee_idx].def_id().unwrap())
60-
});
61-
if any_callee_may_abort {
62-
may_abort_by_id.insert(func_id);
63-
continue;
64-
}
65-
66-
let may_abort_directly = module.functions[func_idx].blocks.iter().any(|block| {
67-
match &block.instructions[..] {
68-
[.., last_normal_inst, terminator_inst]
69-
if last_normal_inst.class.opcode == Op::ExtInst
70-
&& last_normal_inst.operands[0].unwrap_id_ref()
71-
== custom_ext_inst_set_import
72-
&& CustomOp::decode_from_ext_inst(last_normal_inst)
73-
== CustomOp::Abort =>
74-
{
75-
assert_eq!(terminator_inst.class.opcode, Op::Unreachable);
76-
true
77-
}
78-
79-
_ => false,
80-
}
81-
});
82-
if may_abort_directly {
83-
may_abort_by_id.insert(func_id);
84-
}
85-
}
86-
87-
may_abort_by_id
88-
})
89-
.unwrap_or_default();
90-
91-
let functions = module
92-
.functions
93-
.iter()
94-
.map(|f| (f.def_id().unwrap(), f.clone()))
95-
.collect();
96-
let legal_globals = LegalGlobal::gather_from_module(module);
97-
43+
/*
9844
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
9945
// inlines in functions that will get inlined)
10046
let mut dropped_ids = FxHashSet::default();
@@ -123,6 +69,9 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
12369
));
12470
}
12571
}
72+
*/
73+
74+
let legal_globals = LegalGlobal::gather_from_module(module);
12675

12776
let header = module.header.as_mut().unwrap();
12877
// FIXME(eddyb) clippy false positive (separate `map` required for borrowck).
@@ -154,6 +103,8 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
154103
id
155104
}),
156105

106+
func_id_to_idx,
107+
157108
id_to_name: module
158109
.debug_names
159110
.iter()
@@ -173,22 +124,61 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
173124
annotations: &mut module.annotations,
174125
types_global_values: &mut module.types_global_values,
175126

176-
functions: &functions,
177-
legal_globals: &legal_globals,
178-
functions_that_may_abort: &functions_that_may_abort,
127+
legal_globals,
128+
129+
// NOTE(eddyb) this is needed because our custom `Abort` instructions get
130+
// lowered to a simple `OpReturn` in entry-points, but that requires that
131+
// they get inlined all the way up to the entry-points in the first place.
132+
functions_that_may_abort: module
133+
.functions
134+
.iter()
135+
.filter_map(|func| {
136+
let custom_ext_inst_set_import = custom_ext_inst_set_import?;
137+
func.blocks
138+
.iter()
139+
.any(|block| match &block.instructions[..] {
140+
[.., last_normal_inst, terminator_inst]
141+
if last_normal_inst.class.opcode == Op::ExtInst
142+
&& last_normal_inst.operands[0].unwrap_id_ref()
143+
== custom_ext_inst_set_import
144+
&& CustomOp::decode_from_ext_inst(last_normal_inst)
145+
== CustomOp::Abort =>
146+
{
147+
assert_eq!(terminator_inst.class.opcode, Op::Unreachable);
148+
true
149+
}
150+
151+
_ => false,
152+
})
153+
.then_some(func.def_id().unwrap())
154+
})
155+
.collect(),
179156
};
180-
for function in &mut module.functions {
181-
inliner.inline_fn(function);
182-
fuse_trivial_branches(function);
157+
158+
let mut functions: Vec<_> = mem::take(&mut module.functions)
159+
.into_iter()
160+
.map(Ok)
161+
.collect();
162+
163+
// Inline functions in post-order (aka inside-out aka bottom-up) - that is,
164+
// callees are processed before their callers, to avoid duplicating work.
165+
for func_idx in call_graph.post_order() {
166+
let mut function = mem::replace(&mut functions[func_idx], Err(FuncIsBeingInlined)).unwrap();
167+
inliner.inline_fn(&mut function, &functions);
168+
fuse_trivial_branches(&mut function);
169+
functions[func_idx] = Ok(function);
183170
}
184171

172+
module.functions = functions.into_iter().map(|func| func.unwrap()).collect();
173+
174+
/*
185175
// Drop OpName etc. for inlined functions
186176
module.debug_names.retain(|inst| {
187177
!inst.operands.iter().any(|op| {
188178
op.id_ref_any()
189179
.map_or(false, |id| dropped_ids.contains(&id))
190180
})
191-
});
181+
});*/
192182

193183
Ok(())
194184
}
@@ -456,19 +446,27 @@ fn should_inline(
456446
Ok(callee_control.contains(FunctionControl::INLINE))
457447
}
458448

449+
/// Helper error type for the `functions` slice passed to `Inliner::inline_fn`,
450+
/// indicating a `Function` was taken out of its slot while `inline_fn` is processing it.
451+
#[derive(Debug)]
452+
struct FuncIsBeingInlined;
453+
459454
// Steps:
460455
// Move OpVariable decls
461456
// Rewrite return
462457
// Renumber IDs
463458
// Insert blocks
464459

465-
struct Inliner<'m, 'map> {
460+
struct Inliner<'m> {
466461
/// ID of `OpExtInstImport` for our custom "extended instruction set"
467462
/// (see `crate::custom_insts` for more details).
468463
custom_ext_inst_set_import: Word,
469464

470465
op_type_void_id: Word,
471466

467+
/// Map from each function's ID to its index in `functions`.
468+
func_id_to_idx: FxHashMap<Word, usize>,
469+
472470
/// Pre-collected `OpName`s, that can be used to find any function's name
473471
/// during inlining (to be able to generate debuginfo that uses names).
474472
id_to_name: FxHashMap<Word, &'m str>,
@@ -485,13 +483,12 @@ struct Inliner<'m, 'map> {
485483
annotations: &'m mut Vec<Instruction>,
486484
types_global_values: &'m mut Vec<Instruction>,
487485

488-
functions: &'map FunctionMap,
489-
legal_globals: &'map FxHashMap<Word, LegalGlobal>,
490-
functions_that_may_abort: &'map FxHashSet<Word>,
486+
legal_globals: FxHashMap<Word, LegalGlobal>,
487+
functions_that_may_abort: FxHashSet<Word>,
491488
// rewrite_rules: FxHashMap<Word, Word>,
492489
}
493490

494-
impl Inliner<'_, '_> {
491+
impl Inliner<'_> {
495492
fn id(&mut self) -> Word {
496493
next_id(self.header)
497494
}
@@ -536,19 +533,29 @@ impl Inliner<'_, '_> {
536533
inst_id
537534
}
538535

539-
fn inline_fn(&mut self, function: &mut Function) {
536+
fn inline_fn(
537+
&mut self,
538+
function: &mut Function,
539+
functions: &[Result<Function, FuncIsBeingInlined>],
540+
) {
540541
let mut block_idx = 0;
541542
while block_idx < function.blocks.len() {
542543
// If we successfully inlined a block, then repeat processing on the same block, in
543544
// case the newly inlined block has more inlined calls.
544545
// TODO: This is quadratic
545-
if !self.inline_block(function, block_idx) {
546+
if !self.inline_block(function, block_idx, functions) {
547+
// TODO(eddyb) skip past the inlined callee without rescanning it.
546548
block_idx += 1;
547549
}
548550
}
549551
}
550552

551-
fn inline_block(&mut self, caller: &mut Function, block_idx: usize) -> bool {
553+
fn inline_block(
554+
&mut self,
555+
caller: &mut Function,
556+
block_idx: usize,
557+
functions: &[Result<Function, FuncIsBeingInlined>],
558+
) -> bool {
552559
// Find the first inlined OpFunctionCall
553560
let call = caller.blocks[block_idx]
554561
.instructions
@@ -559,8 +566,8 @@ impl Inliner<'_, '_> {
559566
(
560567
index,
561568
inst,
562-
self.functions
563-
.get(&inst.operands[0].id_ref_any().unwrap())
569+
functions[self.func_id_to_idx[&inst.operands[0].id_ref_any().unwrap()]]
570+
.as_ref()
564571
.unwrap(),
565572
)
566573
})
@@ -570,8 +577,8 @@ impl Inliner<'_, '_> {
570577
call_inst: inst,
571578
};
572579
match should_inline(
573-
self.legal_globals,
574-
self.functions_that_may_abort,
580+
&self.legal_globals,
581+
&self.functions_that_may_abort,
575582
f,
576583
Some(call_site),
577584
) {
@@ -583,6 +590,16 @@ impl Inliner<'_, '_> {
583590
None => return false,
584591
Some(call) => call,
585592
};
593+
594+
// Propagate "may abort" from callee to caller (i.e. as aborts get inlined).
595+
if self
596+
.functions_that_may_abort
597+
.contains(&callee.def_id().unwrap())
598+
{
599+
self.functions_that_may_abort
600+
.insert(caller.def_id().unwrap());
601+
}
602+
586603
let call_result_type = {
587604
let ty = call_inst.result_type.unwrap();
588605
if ty == self.op_type_void_id {
@@ -594,6 +611,7 @@ impl Inliner<'_, '_> {
594611
let call_result_id = call_inst.result_id.unwrap();
595612

596613
// Get the debuginfo instructions that apply to the call.
614+
// TODO(eddyb) only one instruction should be necessary here w/ bottom-up.
597615
let custom_ext_inst_set_import = self.custom_ext_inst_set_import;
598616
let call_debug_insts = caller.blocks[block_idx].instructions[..call_index]
599617
.iter()
@@ -868,6 +886,7 @@ impl Inliner<'_, '_> {
868886
..
869887
} = *self;
870888

889+
// TODO(eddyb) kill this as it shouldn't be needed for bottom-up inline.
871890
// HACK(eddyb) this is terrible, but we have to deal with it because of
872891
// how this inliner is outside-in, instead of inside-out, meaning that
873892
// context builds up "outside" of the callee blocks, inside the caller.

crates/rustc_codegen_spirv/src/linker/ipo.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
use indexmap::IndexSet;
66
use rspirv::dr::Module;
7-
use rspirv::spirv::Op;
7+
use rspirv::spirv::{Op, Word};
88
use rustc_data_structures::fx::FxHashMap;
99

1010
// FIXME(eddyb) use newtyped indices and `IndexVec`.
@@ -19,6 +19,9 @@ pub struct CallGraph {
1919

2020
impl CallGraph {
2121
pub fn collect(module: &Module) -> Self {
22+
Self::collect_with_func_id_to_idx(module).0
23+
}
24+
pub fn collect_with_func_id_to_idx(module: &Module) -> (Self, FxHashMap<Word, FuncIdx>) {
2225
let func_id_to_idx: FxHashMap<_, _> = module
2326
.functions
2427
.iter()
@@ -51,10 +54,13 @@ impl CallGraph {
5154
.collect()
5255
})
5356
.collect();
54-
Self {
55-
entry_points,
56-
callees,
57-
}
57+
(
58+
Self {
59+
entry_points,
60+
callees,
61+
},
62+
func_id_to_idx,
63+
)
5864
}
5965

6066
/// Order functions using a post-order traversal, i.e. callees before callers.

0 commit comments

Comments
 (0)