16
16
#include " llvm/Analysis/DXILMetadataAnalysis.h"
17
17
#include " llvm/Analysis/DXILResource.h"
18
18
#include " llvm/CodeGen/Passes.h"
19
+ #include " llvm/IR/Constant.h"
19
20
#include " llvm/IR/DiagnosticInfo.h"
20
21
#include " llvm/IR/IRBuilder.h"
21
22
#include " llvm/IR/Instruction.h"
24
25
#include " llvm/IR/IntrinsicsDirectX.h"
25
26
#include " llvm/IR/Module.h"
26
27
#include " llvm/IR/PassManager.h"
28
+ #include " llvm/IR/Use.h"
27
29
#include " llvm/InitializePasses.h"
28
30
#include " llvm/Pass.h"
29
31
#include " llvm/Support/ErrorHandling.h"
@@ -42,6 +44,7 @@ class OpLowerer {
42
44
DXILResourceTypeMap &DRTM;
43
45
const ModuleMetadataInfo &MMDI;
44
46
SmallVector<CallInst *> CleanupCasts;
47
+ Function *CleanupNURI = nullptr ;
45
48
46
49
public:
47
50
OpLowerer (Module &M, DXILResourceMap &DRM, DXILResourceTypeMap &DRTM,
@@ -195,6 +198,21 @@ class OpLowerer {
195
198
CleanupCasts.clear ();
196
199
}
197
200
201
+ void cleanupNonUniformResourceIndexCalls () {
202
+ // Replace all NonUniformResourceIndex calls with their argument.
203
+ if (!CleanupNURI)
204
+ return ;
205
+ for (User *U : make_early_inc_range (CleanupNURI->users ())) {
206
+ CallInst *CI = dyn_cast<CallInst>(U);
207
+ if (!CI)
208
+ continue ;
209
+ CI->replaceAllUsesWith (CI->getArgOperand (0 ));
210
+ CI->eraseFromParent ();
211
+ }
212
+ CleanupNURI->eraseFromParent ();
213
+ CleanupNURI = nullptr ;
214
+ }
215
+
198
216
// Remove the resource global associated with the handleFromBinding call
199
217
// instruction and their uses as they aren't needed anymore.
200
218
// TODO: We should verify that all the globals get removed.
@@ -229,6 +247,31 @@ class OpLowerer {
229
247
NameGlobal->removeFromParent ();
230
248
}
231
249
250
+ bool hasNonUniformIndex (Value *IndexOp) {
251
+ if (isa<llvm::Constant>(IndexOp))
252
+ return false ;
253
+
254
+ SmallVector<Value *> WorkList;
255
+ WorkList.push_back (IndexOp);
256
+
257
+ while (!WorkList.empty ()) {
258
+ Value *V = WorkList.pop_back_val ();
259
+ if (auto *CI = dyn_cast<CallInst>(V)) {
260
+ if (CI->getCalledFunction ()->getIntrinsicID () ==
261
+ Intrinsic::dx_resource_nonuniformindex)
262
+ return true ;
263
+ }
264
+ if (auto *U = llvm::dyn_cast<llvm::User>(V)) {
265
+ for (llvm::Value *Op : U->operands ()) {
266
+ if (isa<llvm::Constant>(Op))
267
+ continue ;
268
+ WorkList.push_back (Op);
269
+ }
270
+ }
271
+ }
272
+ return false ;
273
+ }
274
+
232
275
[[nodiscard]] bool lowerToCreateHandle (Function &F) {
233
276
IRBuilder<> &IRB = OpBuilder.getIRB ();
234
277
Type *Int8Ty = IRB.getInt8Ty ();
@@ -250,13 +293,12 @@ class OpLowerer {
250
293
IndexOp = IRB.CreateAdd (IndexOp,
251
294
ConstantInt::get (Int32Ty, Binding.LowerBound ));
252
295
253
- // FIXME: The last argument is a NonUniform flag which needs to be set
254
- // based on resource analysis.
255
- // https://github.com/llvm/llvm-project/issues/155701
296
+ bool HasNonUniformIndex =
297
+ (Binding.Size == 1 ) ? false : hasNonUniformIndex (IndexOp);
256
298
std::array<Value *, 4 > Args{
257
299
ConstantInt::get (Int8Ty, llvm::to_underlying (RC)),
258
300
ConstantInt::get (Int32Ty, Binding.RecordID ), IndexOp,
259
- ConstantInt::get (Int1Ty, false )};
301
+ ConstantInt::get (Int1Ty, HasNonUniformIndex )};
260
302
Expected<CallInst *> OpCall =
261
303
OpBuilder.tryCreateOp (OpCode::CreateHandle, Args, CI->getName ());
262
304
if (Error E = OpCall.takeError ())
@@ -300,11 +342,10 @@ class OpLowerer {
300
342
: Binding.LowerBound + Binding.Size - 1 ;
301
343
Constant *ResBind = OpBuilder.getResBind (Binding.LowerBound , UpperBound,
302
344
Binding.Space , RC);
303
- // FIXME: The last argument is a NonUniform flag which needs to be set
304
- // based on resource analysis.
305
- // https://github.com/llvm/llvm-project/issues/155701
306
- Constant *NonUniform = ConstantInt::get (Int1Ty, false );
307
- std::array<Value *, 3 > BindArgs{ResBind, IndexOp, NonUniform};
345
+ bool NonUniformIndex =
346
+ (Binding.Size == 1 ) ? false : hasNonUniformIndex (IndexOp);
347
+ Constant *NonUniformOp = ConstantInt::get (Int1Ty, NonUniformIndex);
348
+ std::array<Value *, 3 > BindArgs{ResBind, IndexOp, NonUniformOp};
308
349
Expected<CallInst *> OpBind = OpBuilder.tryCreateOp (
309
350
OpCode::CreateHandleFromBinding, BindArgs, CI->getName ());
310
351
if (Error E = OpBind.takeError ())
@@ -868,6 +909,11 @@ class OpLowerer {
868
909
case Intrinsic::dx_resource_getpointer:
869
910
HasErrors |= lowerGetPointer (F);
870
911
break ;
912
+ case Intrinsic::dx_resource_nonuniformindex:
913
+ assert (!CleanupNURI &&
914
+ " overloaded llvm.dx.resource.nonuniformindex intrinsics?" );
915
+ CleanupNURI = &F;
916
+ break ;
871
917
case Intrinsic::dx_resource_load_typedbuffer:
872
918
HasErrors |= lowerTypedBufferLoad (F, /* HasCheckBit=*/ true );
873
919
break ;
@@ -908,8 +954,10 @@ class OpLowerer {
908
954
}
909
955
Updated = true ;
910
956
}
911
- if (Updated && !HasErrors)
957
+ if (Updated && !HasErrors) {
912
958
cleanupHandleCasts ();
959
+ cleanupNonUniformResourceIndexCalls ();
960
+ }
913
961
914
962
return Updated;
915
963
}
0 commit comments