1616#include " llvm/Analysis/DXILMetadataAnalysis.h"
1717#include " llvm/Analysis/DXILResource.h"
1818#include " llvm/CodeGen/Passes.h"
19+ #include " llvm/IR/Constant.h"
1920#include " llvm/IR/DiagnosticInfo.h"
2021#include " llvm/IR/IRBuilder.h"
2122#include " llvm/IR/Instruction.h"
2425#include " llvm/IR/IntrinsicsDirectX.h"
2526#include " llvm/IR/Module.h"
2627#include " llvm/IR/PassManager.h"
28+ #include " llvm/IR/Use.h"
2729#include " llvm/InitializePasses.h"
2830#include " llvm/Pass.h"
2931#include " llvm/Support/ErrorHandling.h"
@@ -42,6 +44,7 @@ class OpLowerer {
4244 DXILResourceTypeMap &DRTM;
4345 const ModuleMetadataInfo &MMDI;
4446 SmallVector<CallInst *> CleanupCasts;
47+ Function *CleanupNURI = nullptr ;
4548
4649public:
4750 OpLowerer (Module &M, DXILResourceMap &DRM, DXILResourceTypeMap &DRTM,
@@ -195,6 +198,21 @@ class OpLowerer {
195198 CleanupCasts.clear ();
196199 }
197200
201+ void cleanupNonUniformResourceIndexCalls () {
202+ // Replace all NonUniformResourceIndex calls with their argument.
203+ if (!CleanupNURI)
204+ return ;
205+ for (User *U : make_early_inc_range (CleanupNURI->users ())) {
206+ CallInst *CI = dyn_cast<CallInst>(U);
207+ if (!CI)
208+ continue ;
209+ CI->replaceAllUsesWith (CI->getArgOperand (0 ));
210+ CI->eraseFromParent ();
211+ }
212+ CleanupNURI->eraseFromParent ();
213+ CleanupNURI = nullptr ;
214+ }
215+
198216 // Remove the resource global associated with the handleFromBinding call
199217 // instruction and their uses as they aren't needed anymore.
200218 // TODO: We should verify that all the globals get removed.
@@ -229,6 +247,31 @@ class OpLowerer {
229247 NameGlobal->removeFromParent ();
230248 }
231249
250+ bool hasNonUniformIndex (Value *IndexOp) {
251+ if (isa<llvm::Constant>(IndexOp))
252+ return false ;
253+
254+ SmallVector<Value *> WorkList;
255+ WorkList.push_back (IndexOp);
256+
257+ while (!WorkList.empty ()) {
258+ Value *V = WorkList.pop_back_val ();
259+ if (auto *CI = dyn_cast<CallInst>(V)) {
260+ if (CI->getCalledFunction ()->getIntrinsicID () ==
261+ Intrinsic::dx_resource_nonuniformindex)
262+ return true ;
263+ }
264+ if (auto *U = llvm::dyn_cast<llvm::User>(V)) {
265+ for (llvm::Value *Op : U->operands ()) {
266+ if (isa<llvm::Constant>(Op))
267+ continue ;
268+ WorkList.push_back (Op);
269+ }
270+ }
271+ }
272+ return false ;
273+ }
274+
232275 [[nodiscard]] bool lowerToCreateHandle (Function &F) {
233276 IRBuilder<> &IRB = OpBuilder.getIRB ();
234277 Type *Int8Ty = IRB.getInt8Ty ();
@@ -250,13 +293,12 @@ class OpLowerer {
250293 IndexOp = IRB.CreateAdd (IndexOp,
251294 ConstantInt::get (Int32Ty, Binding.LowerBound ));
252295
253- // FIXME: The last argument is a NonUniform flag which needs to be set
254- // based on resource analysis.
255- // https://github.com/llvm/llvm-project/issues/155701
296+ bool HasNonUniformIndex =
297+ (Binding.Size == 1 ) ? false : hasNonUniformIndex (IndexOp);
256298 std::array<Value *, 4 > Args{
257299 ConstantInt::get (Int8Ty, llvm::to_underlying (RC)),
258300 ConstantInt::get (Int32Ty, Binding.RecordID ), IndexOp,
259- ConstantInt::get (Int1Ty, false )};
301+ ConstantInt::get (Int1Ty, HasNonUniformIndex )};
260302 Expected<CallInst *> OpCall =
261303 OpBuilder.tryCreateOp (OpCode::CreateHandle, Args, CI->getName ());
262304 if (Error E = OpCall.takeError ())
@@ -300,11 +342,10 @@ class OpLowerer {
300342 : Binding.LowerBound + Binding.Size - 1 ;
301343 Constant *ResBind = OpBuilder.getResBind (Binding.LowerBound , UpperBound,
302344 Binding.Space , RC);
303- // FIXME: The last argument is a NonUniform flag which needs to be set
304- // based on resource analysis.
305- // https://github.com/llvm/llvm-project/issues/155701
306- Constant *NonUniform = ConstantInt::get (Int1Ty, false );
307- std::array<Value *, 3 > BindArgs{ResBind, IndexOp, NonUniform};
345+ bool NonUniformIndex =
346+ (Binding.Size == 1 ) ? false : hasNonUniformIndex (IndexOp);
347+ Constant *NonUniformOp = ConstantInt::get (Int1Ty, NonUniformIndex);
348+ std::array<Value *, 3 > BindArgs{ResBind, IndexOp, NonUniformOp};
308349 Expected<CallInst *> OpBind = OpBuilder.tryCreateOp (
309350 OpCode::CreateHandleFromBinding, BindArgs, CI->getName ());
310351 if (Error E = OpBind.takeError ())
@@ -868,6 +909,11 @@ class OpLowerer {
868909 case Intrinsic::dx_resource_getpointer:
869910 HasErrors |= lowerGetPointer (F);
870911 break ;
912+ case Intrinsic::dx_resource_nonuniformindex:
913+ assert (!CleanupNURI &&
914+ " overloaded llvm.dx.resource.nonuniformindex intrinsics?" );
915+ CleanupNURI = &F;
916+ break ;
871917 case Intrinsic::dx_resource_load_typedbuffer:
872918 HasErrors |= lowerTypedBufferLoad (F, /* HasCheckBit=*/ true );
873919 break ;
@@ -908,8 +954,10 @@ class OpLowerer {
908954 }
909955 Updated = true ;
910956 }
911- if (Updated && !HasErrors)
957+ if (Updated && !HasErrors) {
912958 cleanupHandleCasts ();
959+ cleanupNonUniformResourceIndexCalls ();
960+ }
913961
914962 return Updated;
915963 }
0 commit comments