Skip to content

Commit 1dfe78c

Browse files
committed
Revert "WIP: mem2reg speedup"
This reverts commit efbf694.
1 parent 93afbf2 commit 1dfe78c

File tree

3 files changed

+46
-63
lines changed

3 files changed

+46
-63
lines changed

crates/rustc_codegen_spirv/src/linker/dce.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
//! *references* a rooted thing is also rooted, not the other way around - but that's the basic
88
//! concept.
99
10-
use rspirv::dr::{Block, Function, Instruction, Module, Operand};
10+
use rspirv::dr::{Function, Instruction, Module, Operand};
1111
use rspirv::spirv::{Decoration, LinkageType, Op, StorageClass, Word};
12-
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
13-
use std::hash::Hash;
12+
use rustc_data_structures::fx::FxIndexSet;
1413

1514
pub fn dce(module: &mut Module) {
1615
let mut rooted = collect_roots(module);
@@ -138,11 +137,11 @@ fn kill_unrooted(module: &mut Module, rooted: &FxIndexSet<Word>) {
138137
}
139138
}
140139

141-
pub fn dce_phi(blocks: &mut FxIndexMap<impl Eq + Hash, &mut Block>) {
140+
pub fn dce_phi(func: &mut Function) {
142141
let mut used = FxIndexSet::default();
143142
loop {
144143
let mut changed = false;
145-
for inst in blocks.values().flat_map(|block| &block.instructions) {
144+
for inst in func.all_inst_iter() {
146145
if inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()) {
147146
for op in &inst.operands {
148147
if let Some(id) = op.id_ref_any() {
@@ -155,7 +154,7 @@ pub fn dce_phi(blocks: &mut FxIndexMap<impl Eq + Hash, &mut Block>) {
155154
break;
156155
}
157156
}
158-
for block in blocks.values_mut() {
157+
for block in &mut func.blocks {
159158
block
160159
.instructions
161160
.retain(|inst| inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()));

crates/rustc_codegen_spirv/src/linker/mem2reg.rs

Lines changed: 39 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,19 @@ use super::simple_passes::outgoing_edges;
1313
use super::{apply_rewrite_rules, id};
1414
use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand};
1515
use rspirv::spirv::{Op, Word};
16-
use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap};
16+
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
1717
use rustc_middle::bug;
1818
use std::collections::hash_map;
1919

20-
// HACK(eddyb) newtype instead of type alias to avoid mistakes.
21-
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
22-
struct LabelId(Word);
23-
2420
pub fn mem2reg(
2521
header: &mut ModuleHeader,
2622
types_global_values: &mut Vec<Instruction>,
2723
pointer_to_pointee: &FxHashMap<Word, Word>,
2824
constants: &FxHashMap<Word, u32>,
2925
func: &mut Function,
3026
) {
31-
// HACK(eddyb) this ad-hoc indexing might be useful elsewhere as well, but
32-
// it's made completely irrelevant by SPIR-T so only applies to legacy code.
33-
let mut blocks: FxIndexMap<_, _> = func
34-
.blocks
35-
.iter_mut()
36-
.map(|block| (LabelId(block.label_id().unwrap()), block))
37-
.collect();
38-
39-
let reachable = compute_reachable(&blocks);
40-
let preds = compute_preds(&blocks, &reachable);
27+
let reachable = compute_reachable(&func.blocks);
28+
let preds = compute_preds(&func.blocks, &reachable);
4129
let idom = compute_idom(&preds, &reachable);
4230
let dominance_frontier = compute_dominance_frontier(&preds, &idom);
4331
loop {
@@ -46,27 +34,31 @@ pub fn mem2reg(
4634
types_global_values,
4735
pointer_to_pointee,
4836
constants,
49-
&mut blocks,
37+
&mut func.blocks,
5038
&dominance_frontier,
5139
);
5240
if !changed {
5341
break;
5442
}
5543
// mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
56-
super::dce::dce_phi(&mut blocks);
44+
super::dce::dce_phi(func);
5745
}
5846
}
5947

60-
fn compute_reachable(blocks: &FxIndexMap<LabelId, &mut Block>) -> Vec<bool> {
61-
fn recurse(blocks: &FxIndexMap<LabelId, &mut Block>, reachable: &mut [bool], block: usize) {
48+
fn label_to_index(blocks: &[Block], id: Word) -> usize {
49+
blocks
50+
.iter()
51+
.position(|b| b.label_id().unwrap() == id)
52+
.unwrap()
53+
}
54+
55+
fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
56+
fn recurse(blocks: &[Block], reachable: &mut [bool], block: usize) {
6257
if !reachable[block] {
6358
reachable[block] = true;
64-
for dest_id in outgoing_edges(blocks[block]) {
65-
recurse(
66-
blocks,
67-
reachable,
68-
blocks.get_index_of(&LabelId(dest_id)).unwrap(),
69-
);
59+
for dest_id in outgoing_edges(&blocks[block]) {
60+
let dest_idx = label_to_index(blocks, dest_id);
61+
recurse(blocks, reachable, dest_idx);
7062
}
7163
}
7264
}
@@ -75,19 +67,17 @@ fn compute_reachable(blocks: &FxIndexMap<LabelId, &mut Block>) -> Vec<bool> {
7567
reachable
7668
}
7769

78-
fn compute_preds(
79-
blocks: &FxIndexMap<LabelId, &mut Block>,
80-
reachable_blocks: &[bool],
81-
) -> Vec<Vec<usize>> {
70+
fn compute_preds(blocks: &[Block], reachable_blocks: &[bool]) -> Vec<Vec<usize>> {
8271
let mut result = vec![vec![]; blocks.len()];
8372
// Do not count unreachable blocks as valid preds of blocks
8473
for (source_idx, source) in blocks
85-
.values()
74+
.iter()
8675
.enumerate()
8776
.filter(|&(b, _)| reachable_blocks[b])
8877
{
8978
for dest_id in outgoing_edges(source) {
90-
result[blocks.get_index_of(&LabelId(dest_id)).unwrap()].push(source_idx);
79+
let dest_idx = label_to_index(blocks, dest_id);
80+
result[dest_idx].push(source_idx);
9181
}
9282
}
9383
result
@@ -171,7 +161,7 @@ fn insert_phis_all(
171161
types_global_values: &mut Vec<Instruction>,
172162
pointer_to_pointee: &FxHashMap<Word, Word>,
173163
constants: &FxHashMap<Word, u32>,
174-
blocks: &mut FxIndexMap<LabelId, &mut Block>,
164+
blocks: &mut [Block],
175165
dominance_frontier: &[FxHashSet<usize>],
176166
) -> bool {
177167
let var_maps_and_types = blocks[0]
@@ -208,11 +198,7 @@ fn insert_phis_all(
208198
rewrite_rules: FxHashMap::default(),
209199
};
210200
renamer.rename(0, None);
211-
// FIXME(eddyb) shouldn't this full rescan of the function be done once?
212-
apply_rewrite_rules(
213-
&renamer.rewrite_rules,
214-
blocks.values_mut().map(|block| &mut **block),
215-
);
201+
apply_rewrite_rules(&renamer.rewrite_rules, blocks);
216202
remove_nops(blocks);
217203
}
218204
remove_old_variables(blocks, &var_maps_and_types);
@@ -230,7 +216,7 @@ struct VarInfo {
230216
fn collect_access_chains(
231217
pointer_to_pointee: &FxHashMap<Word, Word>,
232218
constants: &FxHashMap<Word, u32>,
233-
blocks: &FxIndexMap<LabelId, &mut Block>,
219+
blocks: &[Block],
234220
base_var: Word,
235221
base_var_ty: Word,
236222
) -> Option<FxHashMap<Word, VarInfo>> {
@@ -263,7 +249,7 @@ fn collect_access_chains(
263249
// Loop in case a previous block references a later AccessChain
264250
loop {
265251
let mut changed = false;
266-
for inst in blocks.values().flat_map(|b| &b.instructions) {
252+
for inst in blocks.iter().flat_map(|b| &b.instructions) {
267253
for (index, op) in inst.operands.iter().enumerate() {
268254
if let Operand::IdRef(id) = op
269255
&& variables.contains_key(id)
@@ -317,10 +303,10 @@ fn collect_access_chains(
317303
// same var map (e.g. `s.x = s.y;`).
318304
fn split_copy_memory(
319305
header: &mut ModuleHeader,
320-
blocks: &mut FxIndexMap<LabelId, &mut Block>,
306+
blocks: &mut [Block],
321307
var_map: &FxHashMap<Word, VarInfo>,
322308
) {
323-
for block in blocks.values_mut() {
309+
for block in blocks {
324310
let mut inst_index = 0;
325311
while inst_index < block.instructions.len() {
326312
let inst = &block.instructions[inst_index];
@@ -379,7 +365,7 @@ fn has_store(block: &Block, var_map: &FxHashMap<Word, VarInfo>) -> bool {
379365
}
380366

381367
fn insert_phis(
382-
blocks: &FxIndexMap<LabelId, &mut Block>,
368+
blocks: &[Block],
383369
dominance_frontier: &[FxHashSet<usize>],
384370
var_map: &FxHashMap<Word, VarInfo>,
385371
) -> FxHashSet<usize> {
@@ -388,7 +374,7 @@ fn insert_phis(
388374
let mut ever_on_work_list = FxHashSet::default();
389375
let mut work_list = Vec::new();
390376
let mut blocks_with_phi = FxHashSet::default();
391-
for (block_idx, block) in blocks.values().enumerate() {
377+
for (block_idx, block) in blocks.iter().enumerate() {
392378
if has_store(block, var_map) {
393379
ever_on_work_list.insert(block_idx);
394380
work_list.push(block_idx);
@@ -433,10 +419,10 @@ fn top_stack_or_undef(
433419
}
434420
}
435421

436-
struct Renamer<'a, 'b> {
422+
struct Renamer<'a> {
437423
header: &'a mut ModuleHeader,
438424
types_global_values: &'a mut Vec<Instruction>,
439-
blocks: &'a mut FxIndexMap<LabelId, &'b mut Block>,
425+
blocks: &'a mut [Block],
440426
blocks_with_phi: FxHashSet<usize>,
441427
base_var_type: Word,
442428
var_map: &'a FxHashMap<Word, VarInfo>,
@@ -446,7 +432,7 @@ struct Renamer<'a, 'b> {
446432
rewrite_rules: FxHashMap<Word, Word>,
447433
}
448434

449-
impl Renamer<'_, '_> {
435+
impl Renamer<'_> {
450436
// Returns the phi definition.
451437
fn insert_phi_value(&mut self, block: usize, from_block: usize) -> Word {
452438
let from_block_label = self.blocks[from_block].label_id().unwrap();
@@ -568,8 +554,9 @@ impl Renamer<'_, '_> {
568554
}
569555
}
570556

571-
for dest_id in outgoing_edges(self.blocks[block]).collect::<Vec<_>>() {
572-
let dest_idx = self.blocks.get_index_of(&LabelId(dest_id)).unwrap();
557+
for dest_id in outgoing_edges(&self.blocks[block]).collect::<Vec<_>>() {
558+
// TODO: Don't do this find
559+
let dest_idx = label_to_index(self.blocks, dest_id);
573560
self.rename(dest_idx, Some(block));
574561
}
575562

@@ -579,16 +566,16 @@ impl Renamer<'_, '_> {
579566
}
580567
}
581568

582-
fn remove_nops(blocks: &mut FxIndexMap<LabelId, &mut Block>) {
583-
for block in blocks.values_mut() {
569+
fn remove_nops(blocks: &mut [Block]) {
570+
for block in blocks {
584571
block
585572
.instructions
586573
.retain(|inst| inst.class.opcode != Op::Nop);
587574
}
588575
}
589576

590577
fn remove_old_variables(
591-
blocks: &mut FxIndexMap<LabelId, &mut Block>,
578+
blocks: &mut [Block],
592579
var_maps_and_types: &[(FxHashMap<u32, VarInfo>, u32)],
593580
) {
594581
blocks[0].instructions.retain(|inst| {
@@ -599,7 +586,7 @@ fn remove_old_variables(
599586
.all(|(var_map, _)| !var_map.contains_key(&result_id))
600587
}
601588
});
602-
for block in blocks.values_mut() {
589+
for block in blocks {
603590
block.instructions.retain(|inst| {
604591
!matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain)
605592
|| inst.operands.iter().all(|op| {

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,9 @@ fn id(header: &mut ModuleHeader) -> Word {
8585
result
8686
}
8787

88-
fn apply_rewrite_rules<'a>(
89-
rewrite_rules: &FxHashMap<Word, Word>,
90-
blocks: impl IntoIterator<Item = &'a mut Block>,
91-
) {
88+
fn apply_rewrite_rules(rewrite_rules: &FxHashMap<Word, Word>, blocks: &mut [Block]) {
9289
let all_ids_mut = blocks
93-
.into_iter()
90+
.iter_mut()
9491
.flat_map(|b| b.label.iter_mut().chain(b.instructions.iter_mut()))
9592
.flat_map(|inst| {
9693
inst.result_id

0 commit comments

Comments
 (0)