Skip to content

Commit 365cdc0

Browse files
committed
[2024] linker/mem2reg: index SPIR-V blocks by their label IDs for O(1) lookup.
1 parent 6c187ce commit 365cdc0

File tree

3 files changed

+63
-46
lines changed

3 files changed

+63
-46
lines changed

crates/rustc_codegen_spirv/src/linker/dce.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
//! *references* a rooted thing is also rooted, not the other way around - but that's the basic
88
//! concept.
99
10-
use rspirv::dr::{Function, Instruction, Module, Operand};
10+
use rspirv::dr::{Block, Function, Instruction, Module, Operand};
1111
use rspirv::spirv::{Decoration, LinkageType, Op, StorageClass, Word};
12-
use rustc_data_structures::fx::FxIndexSet;
12+
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
13+
use std::hash::Hash;
1314

1415
pub fn dce(module: &mut Module) {
1516
let mut rooted = collect_roots(module);
@@ -137,11 +138,11 @@ fn kill_unrooted(module: &mut Module, rooted: &FxIndexSet<Word>) {
137138
}
138139
}
139140

140-
pub fn dce_phi(func: &mut Function) {
141+
pub fn dce_phi(blocks: &mut FxIndexMap<impl Eq + Hash, &mut Block>) {
141142
let mut used = FxIndexSet::default();
142143
loop {
143144
let mut changed = false;
144-
for inst in func.all_inst_iter() {
145+
for inst in blocks.values().flat_map(|block| &block.instructions) {
145146
if inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()) {
146147
for op in &inst.operands {
147148
if let Some(id) = op.id_ref_any() {
@@ -154,7 +155,7 @@ pub fn dce_phi(func: &mut Function) {
154155
break;
155156
}
156157
}
157-
for block in &mut func.blocks {
158+
for block in blocks.values_mut() {
158159
block
159160
.instructions
160161
.retain(|inst| inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()));

crates/rustc_codegen_spirv/src/linker/mem2reg.rs

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,31 @@ use super::simple_passes::outgoing_edges;
1313
use super::{apply_rewrite_rules, id};
1414
use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand};
1515
use rspirv::spirv::{Op, Word};
16-
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
16+
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap};
1717
use rustc_middle::bug;
1818
use std::collections::hash_map;
1919

20+
// HACK(eddyb) newtype instead of type alias to avoid mistakes.
21+
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
22+
struct LabelId(Word);
23+
2024
pub fn mem2reg(
2125
header: &mut ModuleHeader,
2226
types_global_values: &mut Vec<Instruction>,
2327
pointer_to_pointee: &FxHashMap<Word, Word>,
2428
constants: &FxHashMap<Word, u32>,
2529
func: &mut Function,
2630
) {
27-
let reachable = compute_reachable(&func.blocks);
28-
let preds = compute_preds(&func.blocks, &reachable);
31+
// HACK(eddyb) this ad-hoc indexing might be useful elsewhere as well, but
32+
// it's made completely irrelevant by SPIR-T so only applies to legacy code.
33+
let mut blocks: FxIndexMap<_, _> = func
34+
.blocks
35+
.iter_mut()
36+
.map(|block| (LabelId(block.label_id().unwrap()), block))
37+
.collect();
38+
39+
let reachable = compute_reachable(&blocks);
40+
let preds = compute_preds(&blocks, &reachable);
2941
let idom = compute_idom(&preds, &reachable);
3042
let dominance_frontier = compute_dominance_frontier(&preds, &idom);
3143
loop {
@@ -34,31 +46,27 @@ pub fn mem2reg(
3446
types_global_values,
3547
pointer_to_pointee,
3648
constants,
37-
&mut func.blocks,
49+
&mut blocks,
3850
&dominance_frontier,
3951
);
4052
if !changed {
4153
break;
4254
}
4355
// mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
44-
super::dce::dce_phi(func);
56+
super::dce::dce_phi(&mut blocks);
4557
}
4658
}
4759

48-
fn label_to_index(blocks: &[Block], id: Word) -> usize {
49-
blocks
50-
.iter()
51-
.position(|b| b.label_id().unwrap() == id)
52-
.unwrap()
53-
}
54-
55-
fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
56-
fn recurse(blocks: &[Block], reachable: &mut [bool], block: usize) {
60+
fn compute_reachable(blocks: &FxIndexMap<LabelId, &mut Block>) -> Vec<bool> {
61+
fn recurse(blocks: &FxIndexMap<LabelId, &mut Block>, reachable: &mut [bool], block: usize) {
5762
if !reachable[block] {
5863
reachable[block] = true;
59-
for dest_id in outgoing_edges(&blocks[block]) {
60-
let dest_idx = label_to_index(blocks, dest_id);
61-
recurse(blocks, reachable, dest_idx);
64+
for dest_id in outgoing_edges(blocks[block]) {
65+
recurse(
66+
blocks,
67+
reachable,
68+
blocks.get_index_of(&LabelId(dest_id)).unwrap(),
69+
);
6270
}
6371
}
6472
}
@@ -67,17 +75,19 @@ fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
6775
reachable
6876
}
6977

70-
fn compute_preds(blocks: &[Block], reachable_blocks: &[bool]) -> Vec<Vec<usize>> {
78+
fn compute_preds(
79+
blocks: &FxIndexMap<LabelId, &mut Block>,
80+
reachable_blocks: &[bool],
81+
) -> Vec<Vec<usize>> {
7182
let mut result = vec![vec![]; blocks.len()];
7283
// Do not count unreachable blocks as valid preds of blocks
7384
for (source_idx, source) in blocks
74-
.iter()
85+
.values()
7586
.enumerate()
7687
.filter(|&(b, _)| reachable_blocks[b])
7788
{
7889
for dest_id in outgoing_edges(source) {
79-
let dest_idx = label_to_index(blocks, dest_id);
80-
result[dest_idx].push(source_idx);
90+
result[blocks.get_index_of(&LabelId(dest_id)).unwrap()].push(source_idx);
8191
}
8292
}
8393
result
@@ -161,7 +171,7 @@ fn insert_phis_all(
161171
types_global_values: &mut Vec<Instruction>,
162172
pointer_to_pointee: &FxHashMap<Word, Word>,
163173
constants: &FxHashMap<Word, u32>,
164-
blocks: &mut [Block],
174+
blocks: &mut FxIndexMap<LabelId, &mut Block>,
165175
dominance_frontier: &[FxHashSet<usize>],
166176
) -> bool {
167177
let var_maps_and_types = blocks[0]
@@ -198,7 +208,11 @@ fn insert_phis_all(
198208
rewrite_rules: FxHashMap::default(),
199209
};
200210
renamer.rename(0, None);
201-
apply_rewrite_rules(&renamer.rewrite_rules, blocks);
211+
// FIXME(eddyb) shouldn't this full rescan of the function be done once?
212+
apply_rewrite_rules(
213+
&renamer.rewrite_rules,
214+
blocks.values_mut().map(|block| &mut **block),
215+
);
202216
remove_nops(blocks);
203217
}
204218
remove_old_variables(blocks, &var_maps_and_types);
@@ -216,7 +230,7 @@ struct VarInfo {
216230
fn collect_access_chains(
217231
pointer_to_pointee: &FxHashMap<Word, Word>,
218232
constants: &FxHashMap<Word, u32>,
219-
blocks: &[Block],
233+
blocks: &FxIndexMap<LabelId, &mut Block>,
220234
base_var: Word,
221235
base_var_ty: Word,
222236
) -> Option<FxHashMap<Word, VarInfo>> {
@@ -249,7 +263,7 @@ fn collect_access_chains(
249263
// Loop in case a previous block references a later AccessChain
250264
loop {
251265
let mut changed = false;
252-
for inst in blocks.iter().flat_map(|b| &b.instructions) {
266+
for inst in blocks.values().flat_map(|b| &b.instructions) {
253267
for (index, op) in inst.operands.iter().enumerate() {
254268
if let Operand::IdRef(id) = op
255269
&& variables.contains_key(id)
@@ -303,10 +317,10 @@ fn collect_access_chains(
303317
// same var map (e.g. `s.x = s.y;`).
304318
fn split_copy_memory(
305319
header: &mut ModuleHeader,
306-
blocks: &mut [Block],
320+
blocks: &mut FxIndexMap<LabelId, &mut Block>,
307321
var_map: &FxHashMap<Word, VarInfo>,
308322
) {
309-
for block in blocks {
323+
for block in blocks.values_mut() {
310324
let mut inst_index = 0;
311325
while inst_index < block.instructions.len() {
312326
let inst = &block.instructions[inst_index];
@@ -365,7 +379,7 @@ fn has_store(block: &Block, var_map: &FxHashMap<Word, VarInfo>) -> bool {
365379
}
366380

367381
fn insert_phis(
368-
blocks: &[Block],
382+
blocks: &FxIndexMap<LabelId, &mut Block>,
369383
dominance_frontier: &[FxHashSet<usize>],
370384
var_map: &FxHashMap<Word, VarInfo>,
371385
) -> FxHashSet<usize> {
@@ -374,7 +388,7 @@ fn insert_phis(
374388
let mut ever_on_work_list = FxHashSet::default();
375389
let mut work_list = Vec::new();
376390
let mut blocks_with_phi = FxHashSet::default();
377-
for (block_idx, block) in blocks.iter().enumerate() {
391+
for (block_idx, block) in blocks.values().enumerate() {
378392
if has_store(block, var_map) {
379393
ever_on_work_list.insert(block_idx);
380394
work_list.push(block_idx);
@@ -419,10 +433,10 @@ fn top_stack_or_undef(
419433
}
420434
}
421435

422-
struct Renamer<'a> {
436+
struct Renamer<'a, 'b> {
423437
header: &'a mut ModuleHeader,
424438
types_global_values: &'a mut Vec<Instruction>,
425-
blocks: &'a mut [Block],
439+
blocks: &'a mut FxIndexMap<LabelId, &'b mut Block>,
426440
blocks_with_phi: FxHashSet<usize>,
427441
base_var_type: Word,
428442
var_map: &'a FxHashMap<Word, VarInfo>,
@@ -432,7 +446,7 @@ struct Renamer<'a> {
432446
rewrite_rules: FxHashMap<Word, Word>,
433447
}
434448

435-
impl Renamer<'_> {
449+
impl Renamer<'_, '_> {
436450
// Returns the phi definition.
437451
fn insert_phi_value(&mut self, block: usize, from_block: usize) -> Word {
438452
let from_block_label = self.blocks[from_block].label_id().unwrap();
@@ -554,9 +568,8 @@ impl Renamer<'_> {
554568
}
555569
}
556570

557-
for dest_id in outgoing_edges(&self.blocks[block]).collect::<Vec<_>>() {
558-
// TODO: Don't do this find
559-
let dest_idx = label_to_index(self.blocks, dest_id);
571+
for dest_id in outgoing_edges(self.blocks[block]).collect::<Vec<_>>() {
572+
let dest_idx = self.blocks.get_index_of(&LabelId(dest_id)).unwrap();
560573
self.rename(dest_idx, Some(block));
561574
}
562575

@@ -566,16 +579,16 @@ impl Renamer<'_> {
566579
}
567580
}
568581

569-
fn remove_nops(blocks: &mut [Block]) {
570-
for block in blocks {
582+
fn remove_nops(blocks: &mut FxIndexMap<LabelId, &mut Block>) {
583+
for block in blocks.values_mut() {
571584
block
572585
.instructions
573586
.retain(|inst| inst.class.opcode != Op::Nop);
574587
}
575588
}
576589

577590
fn remove_old_variables(
578-
blocks: &mut [Block],
591+
blocks: &mut FxIndexMap<LabelId, &mut Block>,
579592
var_maps_and_types: &[(FxHashMap<u32, VarInfo>, u32)],
580593
) {
581594
blocks[0].instructions.retain(|inst| {
@@ -586,7 +599,7 @@ fn remove_old_variables(
586599
.all(|(var_map, _)| !var_map.contains_key(&result_id))
587600
}
588601
});
589-
for block in blocks {
602+
for block in blocks.values_mut() {
590603
block.instructions.retain(|inst| {
591604
!matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
592605
|| inst.operands.iter().all(|op| {

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,12 @@ fn id(header: &mut ModuleHeader) -> Word {
8585
result
8686
}
8787

88-
fn apply_rewrite_rules(rewrite_rules: &FxHashMap<Word, Word>, blocks: &mut [Block]) {
88+
fn apply_rewrite_rules<'a>(
89+
rewrite_rules: &FxHashMap<Word, Word>,
90+
blocks: impl IntoIterator<Item = &'a mut Block>,
91+
) {
8992
let all_ids_mut = blocks
90-
.iter_mut()
93+
.into_iter()
9194
.flat_map(|b| b.label.iter_mut().chain(b.instructions.iter_mut()))
9295
.flat_map(|inst| {
9396
inst.result_id

0 commit comments

Comments
 (0)