Skip to content

Commit cda583e

Browse files
committed
Gather all interface OpVariables into OpEntryPoints.
1 parent 86d669b commit cda583e

File tree

3 files changed

+111
-1
lines changed

3 files changed

+111
-1
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
//! Passes that pertain to `OpEntryPoint`'s "interface variables".
2+
3+
use crate::linker::ipo::CallGraph;
4+
use indexmap::{IndexMap, IndexSet};
5+
use rspirv::dr::{Module, Operand};
6+
use rspirv::spirv::{Op, StorageClass, Word};
7+
use std::mem;
8+
9+
type Id = Word;
10+
11+
/// Update `OpEntryPoint`s to contain all of the `OpVariable`s they reference,
12+
/// whether directly or through some function in their call graph.
13+
///
14+
/// This is needed for (arguably-not-interface) `Private` in SPIR-V >= 1.4,
15+
/// but also any interface variables declared "out of band" (e.g. via `asm!`).
16+
pub fn gather_all_interface_vars_from_uses(module: &mut Module) {
17+
// Start by mapping out which global (i.e. `OpVariable` or constants) IDs
18+
// can be used to access any interface-relevant `OpVariable`s
19+
// (where "interface-relevant" depends on the version, see comments below).
20+
let mut used_vars_per_global_id: IndexMap<Id, IndexSet<Id>> = IndexMap::new();
21+
let version = module.header.as_ref().unwrap().version();
22+
for inst in &module.types_global_values {
23+
let mut used_vars = IndexSet::new();
24+
25+
// Base case: the global itself is an interface-relevant `OpVariable`.
26+
let interface_relevant_var = inst.class.opcode == Op::Variable && {
27+
if version > (1, 3) {
28+
// SPIR-V >= v1.4 includes all OpVariables in the interface.
29+
true
30+
} else {
31+
let storage_class = inst.operands[0].unwrap_storage_class();
32+
// SPIR-V <= v1.3 only includes Input and Output in the interface.
33+
storage_class == StorageClass::Input || storage_class == StorageClass::Output
34+
}
35+
};
36+
if interface_relevant_var {
37+
used_vars.insert(inst.result_id.unwrap());
38+
}
39+
40+
// Nested constant refs (e.g. `&&&0`) can create chains of `OpVariable`s
41+
// where only the outer-most `OpVariable` may be accessed directly,
42+
// but the interface variables need to include all the nesting levels.
43+
used_vars.extend(
44+
inst.operands
45+
.iter()
46+
.filter_map(|operand| operand.id_ref_any())
47+
.filter_map(|id| used_vars_per_global_id.get(&id))
48+
.flatten(),
49+
);
50+
51+
if !used_vars.is_empty() {
52+
used_vars_per_global_id.insert(inst.result_id.unwrap(), used_vars);
53+
}
54+
}
55+
56+
// Initial uses come from functions directly referencing global instructions.
57+
let mut used_vars_per_fn_idx: Vec<IndexSet<Id>> = module
58+
.functions
59+
.iter()
60+
.map(|func| {
61+
func.all_inst_iter()
62+
.flat_map(|inst| &inst.operands)
63+
.filter_map(|operand| operand.id_ref_any())
64+
.filter_map(|id| used_vars_per_global_id.get(&id))
65+
.flatten()
66+
.copied()
67+
.collect()
68+
})
69+
.collect();
70+
71+
// Uses can then be propagated through the call graph, from callee to caller.
72+
let call_graph = CallGraph::collect(module);
73+
for caller_idx in call_graph.post_order() {
74+
let mut used_vars = mem::take(&mut used_vars_per_fn_idx[caller_idx]);
75+
for &callee_idx in &call_graph.callees[caller_idx] {
76+
used_vars.extend(&used_vars_per_fn_idx[callee_idx]);
77+
}
78+
used_vars_per_fn_idx[caller_idx] = used_vars;
79+
}
80+
81+
// All transitive uses are available, add them to `OpEntryPoint`s.
82+
for (i, entry) in module.entry_points.iter_mut().enumerate() {
83+
assert_eq!(entry.class.opcode, Op::EntryPoint);
84+
let &entry_func_idx = call_graph.entry_points.get_index(i).unwrap();
85+
assert_eq!(
86+
module.functions[entry_func_idx].def_id().unwrap(),
87+
entry.operands[1].unwrap_id_ref()
88+
);
89+
90+
// NOTE(eddyb) it might be better to remove any unused vars, or warn
91+
// the user about their presence, but for now this keeps them around.
92+
let mut interface_vars: IndexSet<Id> = entry.operands[3..]
93+
.iter()
94+
.map(|operand| operand.unwrap_id_ref())
95+
.collect();
96+
97+
interface_vars.extend(&used_vars_per_fn_idx[entry_func_idx]);
98+
99+
entry.operands.truncate(3);
100+
entry
101+
.operands
102+
.extend(interface_vars.iter().map(|&id| Operand::IdRef(id)));
103+
}
104+
}

crates/rustc_codegen_spirv/src/linker/ipo.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ pub struct CallGraph {
1414
pub entry_points: IndexSet<FuncIdx>,
1515

1616
/// `callees[i].contains(j)` implies `functions[i]` calls `functions[j]`.
17-
callees: Vec<IndexSet<FuncIdx>>,
17+
pub callees: Vec<IndexSet<FuncIdx>>,
1818
}
1919

2020
impl CallGraph {

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ mod test;
44
mod dce;
55
mod destructure_composites;
66
mod duplicates;
7+
mod entry_interface;
78
mod import_export_link;
89
mod inline;
910
mod ipo;
@@ -270,6 +271,11 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
270271
}
271272
}
272273

274+
{
275+
let _timer = sess.timer("link_gather_all_interface_vars_from_uses");
276+
entry_interface::gather_all_interface_vars_from_uses(&mut output);
277+
}
278+
273279
if opts.spirv_metadata == SpirvMetadata::NameVariables {
274280
let _timer = sess.timer("link_name_variables");
275281
simple_passes::name_variables_pass(&mut output);

0 commit comments

Comments
 (0)