Skip to content

Commit ee05b44

Browse files
committed
[HLSL] Analyze update counter usage
1 parent c79e867 commit ee05b44

File tree

4 files changed

+205
-0
lines changed

4 files changed

+205
-0
lines changed

llvm/include/llvm/Analysis/DXILResource.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
#include "llvm/Support/Alignment.h"
1919
#include "llvm/Support/DXILABI.h"
2020

21+
#include <algorithm>
22+
#include <cstdint>
23+
#include <unordered_map>
24+
2125
namespace llvm {
2226
class CallInst;
2327
class DataLayout;
@@ -407,6 +411,69 @@ class DXILResourceTypeMap {
407411
}
408412
};
409413

414+
enum ResourceCounterDirection {
415+
Increment,
416+
Decrement,
417+
Unknown,
418+
};
419+
420+
class DXILResourceCounterDirectionMap {
421+
std::vector<std::pair<dxil::ResourceBindingInfo, ResourceCounterDirection>> CounterDirections;
422+
423+
public:
424+
bool invalidate(Module &M, const PreservedAnalyses &PA,
425+
ModuleAnalysisManager::Invalidator &Inv);
426+
427+
void populate(Module &M, ModuleAnalysisManager &AM);
428+
429+
ResourceCounterDirection operator[](const dxil::ResourceBindingInfo &Info) const {
430+
auto Lower = std::lower_bound(CounterDirections.begin(), CounterDirections.end(), std::pair{Info, ResourceCounterDirection::Unknown}, [](auto lhs, auto rhs){
431+
return lhs.first < rhs.first;
432+
});
433+
434+
if (Lower == CounterDirections.end()) {
435+
return ResourceCounterDirection::Unknown;
436+
}
437+
438+
if (Lower->first != Info) {
439+
return ResourceCounterDirection::Unknown;
440+
}
441+
442+
return Lower->second;
443+
}
444+
};
445+
446+
class DXILResourceCounterDirectionAnalysis
447+
: public AnalysisInfoMixin<DXILResourceCounterDirectionAnalysis> {
448+
friend AnalysisInfoMixin<DXILResourceCounterDirectionAnalysis>;
449+
450+
static AnalysisKey Key;
451+
452+
public:
453+
using Result = DXILResourceCounterDirectionMap;
454+
455+
DXILResourceCounterDirectionMap run(Module &M, ModuleAnalysisManager &AM) {
456+
DXILResourceCounterDirectionMap DRCDM{};
457+
DRCDM.populate(M, AM);
458+
return DRCDM;
459+
}
460+
};
461+
462+
class DXILResourceCounterDirectionWrapperPass : public ImmutablePass {
463+
DXILResourceCounterDirectionMap DRCDM;
464+
465+
virtual void anchor();
466+
467+
public:
468+
static char ID;
469+
DXILResourceCounterDirectionWrapperPass();
470+
471+
DXILResourceCounterDirectionMap &getResourceCounterDirectionMap() { return DRCDM; }
472+
const DXILResourceCounterDirectionMap &getResourceCounterDirectionMap() const { return DRCDM; }
473+
};
474+
475+
ModulePass *createDXILResourceCounterDirectionWrapperPassPass();
476+
410477
class DXILResourceTypeAnalysis
411478
: public AnalysisInfoMixin<DXILResourceTypeAnalysis> {
412479
friend AnalysisInfoMixin<DXILResourceTypeAnalysis>;

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ void initializeDCELegacyPassPass(PassRegistry &);
8585
void initializeDXILMetadataAnalysisWrapperPassPass(PassRegistry &);
8686
void initializeDXILMetadataAnalysisWrapperPrinterPass(PassRegistry &);
8787
void initializeDXILResourceBindingWrapperPassPass(PassRegistry &);
88+
void initializeDXILResourceCounterDirectionWrapperPassPass(PassRegistry &);
8889
void initializeDXILResourceTypeWrapperPassPass(PassRegistry &);
8990
void initializeDeadMachineInstructionElimPass(PassRegistry &);
9091
void initializeDebugifyMachineModulePass(PassRegistry &);

llvm/lib/Analysis/DXILResource.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/IR/Module.h"
2020
#include "llvm/InitializePasses.h"
2121
#include "llvm/Support/FormatVariadic.h"
22+
#include <algorithm>
2223

2324
#define DEBUG_TYPE "dxil-resource"
2425

@@ -818,8 +819,88 @@ DXILBindingMap::findByUse(const Value *Key) const {
818819

819820
//===----------------------------------------------------------------------===//
820821

822+
static bool isUpdateCounterIntrinsic(Function &F) {
823+
return F.getIntrinsicID() == Intrinsic::dx_resource_updatecounter;
824+
}
825+
826+
void DXILResourceCounterDirectionMap::populate(Module &M, ModuleAnalysisManager &AM) {
827+
DXILBindingMap &DBM = AM.getResult<DXILResourceBindingAnalysis>(M);
828+
CounterDirections.clear();
829+
830+
for (Function &F : M.functions()) {
831+
if (!isUpdateCounterIntrinsic(F))
832+
continue;
833+
834+
for (const User *U : F.users()) {
835+
const CallInst *CI = dyn_cast<CallInst>(U);
836+
assert(CI && "Users of dx_resource_updateCounter must be call instrs");
837+
838+
// Determine if the use is an increment or decrement
839+
Value *CountArg = CI->getArgOperand(1);
840+
ConstantInt *CountValue = dyn_cast<ConstantInt>(CountArg);
841+
int64_t CountLiteral = CountValue->getSExtValue();
842+
843+
ResourceCounterDirection Direction = ResourceCounterDirection::Unknown;
844+
if (CountLiteral > 0) {
845+
Direction = ResourceCounterDirection::Increment;
846+
}
847+
if (CountLiteral < 0) {
848+
Direction = ResourceCounterDirection::Decrement;
849+
}
850+
851+
852+
// Collect all potential creation points for the handle arg
853+
Value *HandleArg = CI->getArgOperand(0);
854+
SmallVector<dxil::ResourceBindingInfo> RBInfos = DBM.findByUse(HandleArg);
855+
for(const dxil::ResourceBindingInfo RBInfo : RBInfos) {
856+
CounterDirections.emplace_back(RBInfo, Direction);
857+
}
858+
}
859+
}
860+
861+
// An entry that is not in the map is considered unknown so its wasted
862+
// overhead and increased complexity to keep it so it should be removed.
863+
const auto RemoveEnd = std::remove_if(CounterDirections.begin(), CounterDirections.end(), [](const auto& Item) {
864+
return Item.second == ResourceCounterDirection::Unknown;
865+
});
866+
867+
// Order for fast lookup
868+
std::sort(CounterDirections.begin(), RemoveEnd);
869+
870+
// Remove the duplicate entries. Since direction is considered for equality
871+
// a unique resource with more than one direction will not be deduped.
872+
const auto UniqueEnd = std::unique(CounterDirections.begin(), RemoveEnd);
873+
874+
// Actually erase the items invalidated by remove_if + unique
875+
CounterDirections.erase(UniqueEnd, CounterDirections.end());
876+
877+
// If any duplicate entries still exist at this point then it must be a
878+
// resource that was both incremented and decremented which is not allowed.
879+
const auto DuplicateEntry = std::adjacent_find(CounterDirections.begin(), CounterDirections.end(), [](const auto& LHS, const auto& RHS){
880+
return LHS.first == RHS.first;
881+
});
882+
if (DuplicateEntry == CounterDirections.end())
883+
return;
884+
885+
// TODO: Emit an error message
886+
assert(CounterDirections.size() == 1 && "dups found");
887+
}
888+
889+
bool DXILResourceCounterDirectionMap::invalidate(Module &M, const PreservedAnalyses &PA,
890+
ModuleAnalysisManager::Invalidator &Inv) {
891+
// Passes that introduce resource types must explicitly invalidate this pass.
892+
// auto PAC = PA.getChecker<DXILResourceTypeAnalysis>();
893+
// return !PAC.preservedWhenStateless();
894+
return false;
895+
}
896+
897+
void DXILResourceCounterDirectionWrapperPass::anchor() {}
898+
899+
//===----------------------------------------------------------------------===//
900+
821901
AnalysisKey DXILResourceTypeAnalysis::Key;
822902
AnalysisKey DXILResourceBindingAnalysis::Key;
903+
AnalysisKey DXILResourceCounterDirectionAnalysis::Key;
823904

824905
DXILBindingMap DXILResourceBindingAnalysis::run(Module &M,
825906
ModuleAnalysisManager &AM) {
@@ -838,6 +919,13 @@ DXILResourceBindingPrinterPass::run(Module &M, ModuleAnalysisManager &AM) {
838919
return PreservedAnalyses::all();
839920
}
840921

922+
INITIALIZE_PASS(DXILResourceCounterDirectionWrapperPass, "dxil-resource-counter",
923+
"DXIL Resource Counter Analysis", false, true)
924+
925+
DXILResourceCounterDirectionWrapperPass::DXILResourceCounterDirectionWrapperPass() : ImmutablePass(ID) {
926+
initializeDXILResourceTypeWrapperPassPass(*PassRegistry::getPassRegistry());
927+
}
928+
841929
void DXILResourceTypeWrapperPass::anchor() {}
842930

843931
DXILResourceTypeWrapperPass::DXILResourceTypeWrapperPass() : ImmutablePass(ID) {
@@ -847,11 +935,16 @@ DXILResourceTypeWrapperPass::DXILResourceTypeWrapperPass() : ImmutablePass(ID) {
847935
INITIALIZE_PASS(DXILResourceTypeWrapperPass, "dxil-resource-type",
848936
"DXIL Resource Type Analysis", false, true)
849937
char DXILResourceTypeWrapperPass::ID = 0;
938+
char DXILResourceCounterDirectionWrapperPass::ID = 0;
850939

851940
ModulePass *llvm::createDXILResourceTypeWrapperPassPass() {
852941
return new DXILResourceTypeWrapperPass();
853942
}
854943

944+
ModulePass *llvm::createDXILResourceCounterDirectionWrapperPassPass() {
945+
return new DXILResourceCounterDirectionWrapperPass();
946+
}
947+
855948
DXILResourceBindingWrapperPass::DXILResourceBindingWrapperPass()
856949
: ModulePass(ID) {
857950
initializeDXILResourceBindingWrapperPassPass(

llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class UniqueResourceFromUseTest : public testing::Test {
3535
PB->registerModuleAnalyses(*MAM);
3636
MAM->registerPass([&] { return DXILResourceTypeAnalysis(); });
3737
MAM->registerPass([&] { return DXILResourceBindingAnalysis(); });
38+
MAM->registerPass([&] { return DXILResourceCounterDirectionAnalysis(); });
3839
}
3940

4041
virtual void TearDown() {
@@ -280,4 +281,47 @@ declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", flo
280281
}
281282
}
282283

284+
TEST_F(UniqueResourceFromUseTest, TestResourceCounter) {
285+
StringRef Assembly = R"(
286+
define void @main() {
287+
entry:
288+
%handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
289+
call i32 @llvm.dx.resource.updatecounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
290+
call i32 @llvm.dx.resource.updatecounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
291+
call i32 @llvm.dx.resource.updatecounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
292+
call i32 @llvm.dx.resource.updatecounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
293+
call i32 @llvm.dx.resource.updatecounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
294+
call i32 @llvm.dx.resource.updatecounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %handle, i8 -1)
295+
call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
296+
ret void
297+
}
298+
299+
declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
300+
declare i32 @llvm.dx.resource.updatecounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0), i8)
301+
declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
302+
)";
303+
304+
LLVMContext Context;
305+
SMDiagnostic Error;
306+
auto M = parseAssemblyString(Assembly, Error, Context);
307+
ASSERT_TRUE(M) << "Bad assembly?";
308+
309+
const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
310+
const DXILResourceCounterDirectionMap &DCDM = MAM->getResult<DXILResourceCounterDirectionAnalysis>(*M);
311+
312+
for (const Function &F : M->functions()) {
313+
if (F.getName() != "a.func") {
314+
continue;
315+
}
316+
317+
for (const User *U : F.users()) {
318+
const CallInst *CI = cast<CallInst>(U);
319+
const Value *Handle = CI->getArgOperand(0);
320+
const auto Bindings = DBM.findByUse(Handle);
321+
ASSERT_EQ(Bindings.size(), 1u);
322+
ASSERT_EQ(DCDM[Bindings.front()], ResourceCounterDirection::Decrement);
323+
}
324+
}
325+
}
326+
283327
} // namespace

0 commit comments

Comments
 (0)