7
7
//! *references* a rooted thing is also rooted, not the other way around - but that's the basic
8
8
//! concept.
9
9
10
- use rspirv:: dr:: { Function , Instruction , Module } ;
11
- use rspirv:: spirv:: { Op , Word } ;
10
+ use rspirv:: dr:: { Function , Instruction , Module , Operand } ;
11
+ use rspirv:: spirv:: { Op , StorageClass , Word } ;
12
12
use rustc_data_structures:: fx:: FxHashSet ;
13
13
14
14
pub fn dce ( module : & mut Module ) {
@@ -36,8 +36,29 @@ fn spread_roots(module: &Module, rooted: &mut FxHashSet<Word>) -> bool {
36
36
}
37
37
for func in & module. functions {
38
38
if rooted. contains ( & func. def_id ( ) . unwrap ( ) ) {
39
- for inst in func. all_inst_iter ( ) {
40
- any |= root ( inst, rooted) ;
39
+ // NB (Mobius 2021) - since later insts are much more likely to reference
40
+ // earlier insts, by reversing the iteration order, we're more likely to root the
41
+ // entire relevant function at once.
42
+ // See https://github.com/EmbarkStudios/rust-gpu/pull/691#discussion_r681477091
43
+ for inst in func
44
+ . end
45
+ . iter ( )
46
+ . chain (
47
+ func. blocks
48
+ . iter ( )
49
+ . rev ( )
50
+ . flat_map ( |b| b. instructions . iter ( ) . rev ( ) . chain ( b. label . iter ( ) ) ) ,
51
+ )
52
+ . chain ( func. parameters . iter ( ) . rev ( ) )
53
+ . chain ( func. def . iter ( ) )
54
+ {
55
+ if !instruction_is_pure ( inst) {
56
+ any |= root ( inst, rooted) ;
57
+ } else if let Some ( id) = inst. result_id {
58
+ if rooted. contains ( & id) {
59
+ any |= root ( inst, rooted) ;
60
+ }
61
+ }
41
62
}
42
63
}
43
64
}
@@ -90,6 +111,13 @@ fn kill_unrooted(module: &mut Module, rooted: &FxHashSet<Word>) {
90
111
module
91
112
. functions
92
113
. retain ( |f| is_rooted ( f. def . as_ref ( ) . unwrap ( ) , rooted) ) ;
114
+ module. functions . iter_mut ( ) . for_each ( |fun| {
115
+ fun. blocks . iter_mut ( ) . for_each ( |block| {
116
+ block
117
+ . instructions
118
+ . retain ( |inst| !instruction_is_pure ( inst) || is_rooted ( inst, rooted) ) ;
119
+ } ) ;
120
+ } ) ;
93
121
}
94
122
95
123
pub fn dce_phi ( func : & mut Function ) {
@@ -115,3 +143,127 @@ pub fn dce_phi(func: &mut Function) {
115
143
. retain ( |inst| inst. class . opcode != Op :: Phi || used. contains ( & inst. result_id . unwrap ( ) ) ) ;
116
144
}
117
145
}
146
+
147
+ fn instruction_is_pure ( inst : & Instruction ) -> bool {
148
+ use Op :: * ;
149
+ match inst. class . opcode {
150
+ Nop
151
+ | Undef
152
+ | ConstantTrue
153
+ | ConstantFalse
154
+ | Constant
155
+ | ConstantComposite
156
+ | ConstantSampler
157
+ | ConstantNull
158
+ | AccessChain
159
+ | InBoundsAccessChain
160
+ | PtrAccessChain
161
+ | ArrayLength
162
+ | InBoundsPtrAccessChain
163
+ | CompositeConstruct
164
+ | CompositeExtract
165
+ | CopyObject
166
+ | Transpose
167
+ | ConvertFToU
168
+ | ConvertFToS
169
+ | ConvertSToF
170
+ | ConvertUToF
171
+ | UConvert
172
+ | SConvert
173
+ | FConvert
174
+ | QuantizeToF16
175
+ | ConvertPtrToU
176
+ | SatConvertSToU
177
+ | SatConvertUToS
178
+ | ConvertUToPtr
179
+ | PtrCastToGeneric
180
+ | GenericCastToPtr
181
+ | GenericCastToPtrExplicit
182
+ | Bitcast
183
+ | SNegate
184
+ | FNegate
185
+ | IAdd
186
+ | FAdd
187
+ | ISub
188
+ | FSub
189
+ | IMul
190
+ | FMul
191
+ | UDiv
192
+ | SDiv
193
+ | FDiv
194
+ | UMod
195
+ | SRem
196
+ | SMod
197
+ | FRem
198
+ | FMod
199
+ | VectorTimesScalar
200
+ | MatrixTimesScalar
201
+ | VectorTimesMatrix
202
+ | MatrixTimesVector
203
+ | MatrixTimesMatrix
204
+ | OuterProduct
205
+ | Dot
206
+ | IAddCarry
207
+ | ISubBorrow
208
+ | UMulExtended
209
+ | SMulExtended
210
+ | Any
211
+ | All
212
+ | IsNan
213
+ | IsInf
214
+ | IsFinite
215
+ | IsNormal
216
+ | SignBitSet
217
+ | LessOrGreater
218
+ | Ordered
219
+ | Unordered
220
+ | LogicalEqual
221
+ | LogicalNotEqual
222
+ | LogicalOr
223
+ | LogicalAnd
224
+ | LogicalNot
225
+ | Select
226
+ | IEqual
227
+ | INotEqual
228
+ | UGreaterThan
229
+ | SGreaterThan
230
+ | UGreaterThanEqual
231
+ | SGreaterThanEqual
232
+ | ULessThan
233
+ | SLessThan
234
+ | ULessThanEqual
235
+ | SLessThanEqual
236
+ | FOrdEqual
237
+ | FUnordEqual
238
+ | FOrdNotEqual
239
+ | FUnordNotEqual
240
+ | FOrdLessThan
241
+ | FUnordLessThan
242
+ | FOrdGreaterThan
243
+ | FUnordGreaterThan
244
+ | FOrdLessThanEqual
245
+ | FUnordLessThanEqual
246
+ | FOrdGreaterThanEqual
247
+ | FUnordGreaterThanEqual
248
+ | ShiftRightLogical
249
+ | ShiftRightArithmetic
250
+ | ShiftLeftLogical
251
+ | BitwiseOr
252
+ | BitwiseXor
253
+ | BitwiseAnd
254
+ | Not
255
+ | BitFieldInsert
256
+ | BitFieldSExtract
257
+ | BitFieldUExtract
258
+ | BitReverse
259
+ | BitCount
260
+ | Phi
261
+ | SizeOf
262
+ | CopyLogical
263
+ | PtrEqual
264
+ | PtrNotEqual
265
+ | PtrDiff => true ,
266
+ Variable => inst. operands . get ( 0 ) == Some ( & Operand :: StorageClass ( StorageClass :: Function ) ) ,
267
+ _ => false ,
268
+ }
269
+ }
0 commit comments