Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 81 additions & 87 deletions lib/HLSL/DxilScalarizeVectorIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,85 +28,9 @@
using namespace llvm;
using namespace hlsl;

static bool scalarizeVectorLoad(hlsl::OP *HlslOP, const DataLayout &DL,
CallInst *CI);
static bool scalarizeVectorStore(hlsl::OP *HlslOP, const DataLayout &DL,
CallInst *CI);
static bool scalarizeVectorIntrinsic(hlsl::OP *HlslOP, CallInst *CI);
static bool scalarizeVectorReduce(hlsl::OP *HlslOP, CallInst *CI);
static bool scalarizeVectorDot(hlsl::OP *HlslOP, CallInst *CI);
static bool scalarizeVectorWaveMatch(hlsl::OP *HlslOP, CallInst *CI);
namespace {

class DxilScalarizeVectorIntrinsics : public ModulePass {
public:
static char ID; // Pass identification, replacement for typeid
explicit DxilScalarizeVectorIntrinsics() : ModulePass(ID) {}

StringRef getPassName() const override {
return "DXIL scalarize vector load/stores";
}

bool runOnModule(Module &M) override {
DxilModule &DM = M.GetOrCreateDxilModule();
// Shader Model 6.9 allows native vectors and doesn't need this pass.
if (DM.GetShaderModel()->IsSM69Plus())
return false;

bool Changed = false;

hlsl::OP *HlslOP = DM.GetOP();

// Iterate and scalarize native vector loads, stores, and other intrinsics.
for (auto F = M.functions().begin(); F != M.functions().end();) {
Function *Func = &*(F++);
DXIL::OpCodeClass OpClass;
if (!HlslOP->GetOpCodeClass(Func, OpClass))
continue;

const bool CouldRewrite =
(Func->getReturnType()->isVectorTy() ||
OpClass == DXIL::OpCodeClass::RawBufferVectorLoad ||
OpClass == DXIL::OpCodeClass::RawBufferVectorStore ||
OpClass == DXIL::OpCodeClass::VectorReduce ||
OpClass == DXIL::OpCodeClass::Dot ||
OpClass == DXIL::OpCodeClass::WaveMatch);
if (!CouldRewrite)
continue;

for (auto U = Func->user_begin(), UE = Func->user_end(); U != UE;) {
CallInst *CI = cast<CallInst>(*(U++));

// Handle DXIL operations with complex signatures separately
switch (OpClass) {
case DXIL::OpCodeClass::RawBufferVectorLoad:
Changed |= scalarizeVectorLoad(HlslOP, M.getDataLayout(), CI);
continue;
case DXIL::OpCodeClass::RawBufferVectorStore:
Changed |= scalarizeVectorStore(HlslOP, M.getDataLayout(), CI);
continue;
case DXIL::OpCodeClass::VectorReduce:
Changed |= scalarizeVectorReduce(HlslOP, CI);
continue;
case DXIL::OpCodeClass::Dot:
Changed |= scalarizeVectorDot(HlslOP, CI);
continue;
case DXIL::OpCodeClass::WaveMatch:
Changed |= scalarizeVectorWaveMatch(HlslOP, CI);
continue;
default:
break;
}

// Handle DXIL Ops with vector return matching the vector params
if (Func->getReturnType()->isVectorTy())
Changed |= scalarizeVectorIntrinsic(HlslOP, CI);
}
}
return Changed;
}
};

static unsigned GetRawBufferMask(unsigned NumComponents) {
unsigned GetRawBufferMask(unsigned NumComponents) {
switch (NumComponents) {
case 0:
return 0;
Expand All @@ -123,8 +47,7 @@ static unsigned GetRawBufferMask(unsigned NumComponents) {
return DXIL::kCompMask_All;
}

static bool scalarizeVectorLoad(hlsl::OP *HlslOP, const DataLayout &DL,
CallInst *CI) {
bool scalarizeVectorLoad(hlsl::OP *HlslOP, const DataLayout &DL, CallInst *CI) {
IRBuilder<> Builder(CI);
// Collect the information required to break this into scalar ops from args.
DxilInst_RawBufferVectorLoad VecLd(CI);
Expand Down Expand Up @@ -177,7 +100,7 @@ static bool scalarizeVectorLoad(hlsl::OP *HlslOP, const DataLayout &DL,
// Replace users of the vector extracted from the vector load resret.
Value *Status = nullptr;
for (auto CU = CI->user_begin(), CE = CI->user_end(); CU != CE;) {
auto EV = cast<ExtractValueInst>(*(CU++));
auto *EV = cast<ExtractValueInst>(*(CU++));
unsigned Ix = EV->getIndices()[0];
if (Ix == 0) {
// Handle value uses.
Expand All @@ -195,8 +118,8 @@ static bool scalarizeVectorLoad(hlsl::OP *HlslOP, const DataLayout &DL,
return true;
}

static bool scalarizeVectorStore(hlsl::OP *HlslOP, const DataLayout &DL,
CallInst *CI) {
bool scalarizeVectorStore(hlsl::OP *HlslOP, const DataLayout &DL,
CallInst *CI) {
IRBuilder<> Builder(CI);
// Collect the information required to break this into scalar ops from args.
DxilInst_RawBufferVectorStore VecSt(CI);
Expand Down Expand Up @@ -259,7 +182,7 @@ static bool scalarizeVectorStore(hlsl::OP *HlslOP, const DataLayout &DL,
return true;
}

static bool scalarizeVectorReduce(hlsl::OP *HlslOP, CallInst *CI) {
bool scalarizeVectorReduce(CallInst *CI) {
IRBuilder<> Builder(CI);

OP::OpCode ReduceOp = OP::getOpCode(CI);
Expand Down Expand Up @@ -288,7 +211,7 @@ static bool scalarizeVectorReduce(hlsl::OP *HlslOP, CallInst *CI) {
return true;
}

static bool scalarizeVectorWaveMatch(hlsl::OP *HlslOP, CallInst *CI) {
bool scalarizeVectorWaveMatch(hlsl::OP *HlslOP, CallInst *CI) {
IRBuilder<> Builder(CI);
OP::OpCode Opcode = OP::getOpCode(CI);
Value *VecArg = CI->getArgOperand(1);
Expand Down Expand Up @@ -344,7 +267,7 @@ static bool scalarizeVectorWaveMatch(hlsl::OP *HlslOP, CallInst *CI) {
}

// Scalarize vectorized dot product
static bool scalarizeVectorDot(hlsl::OP *HlslOP, CallInst *CI) {
bool scalarizeVectorDot(hlsl::OP *HlslOP, CallInst *CI) {
IRBuilder<> Builder(CI);

Value *AVecArg = CI->getArgOperand(1);
Expand Down Expand Up @@ -406,7 +329,7 @@ static bool scalarizeVectorDot(hlsl::OP *HlslOP, CallInst *CI) {
// Scalarize native vector operation represented by `CI`, generating
// scalar calls for each element of the its vector parameters.
// Use `HlslOP` to retrieve the associated scalar op function.
static bool scalarizeVectorIntrinsic(hlsl::OP *HlslOP, CallInst *CI) {
bool scalarizeVectorIntrinsic(hlsl::OP *HlslOP, CallInst *CI) {

IRBuilder<> Builder(CI);
VectorType *VT = cast<VectorType>(CI->getType());
Expand Down Expand Up @@ -440,6 +363,77 @@ static bool scalarizeVectorIntrinsic(hlsl::OP *HlslOP, CallInst *CI) {
return true;
}

} // namespace

class DxilScalarizeVectorIntrinsics : public ModulePass {
public:
static char ID; // Pass identification, replacement for typeid
explicit DxilScalarizeVectorIntrinsics() : ModulePass(ID) {}

StringRef getPassName() const override {
return "DXIL scalarize vector load/stores";
}

bool runOnModule(Module &M) override {
DxilModule &DM = M.GetOrCreateDxilModule();
// Shader Model 6.9 allows native vectors and doesn't need this pass.
if (DM.GetShaderModel()->IsSM69Plus())
return false;

bool Changed = false;

hlsl::OP *HlslOP = DM.GetOP();

// Iterate and scalarize native vector loads, stores, and other intrinsics.
for (auto F = M.functions().begin(); F != M.functions().end();) {
Function *Func = &*(F++);
DXIL::OpCodeClass OpClass;
if (!HlslOP->GetOpCodeClass(Func, OpClass))
continue;

const bool CouldRewrite =
(Func->getReturnType()->isVectorTy() ||
OpClass == DXIL::OpCodeClass::RawBufferVectorLoad ||
OpClass == DXIL::OpCodeClass::RawBufferVectorStore ||
OpClass == DXIL::OpCodeClass::VectorReduce ||
OpClass == DXIL::OpCodeClass::Dot ||
OpClass == DXIL::OpCodeClass::WaveMatch);
if (!CouldRewrite)
continue;

for (auto U = Func->user_begin(), UE = Func->user_end(); U != UE;) {
CallInst *CI = cast<CallInst>(*(U++));

// Handle DXIL operations with complex signatures separately
switch (OpClass) {
case DXIL::OpCodeClass::RawBufferVectorLoad:
Changed |= scalarizeVectorLoad(HlslOP, M.getDataLayout(), CI);
continue;
case DXIL::OpCodeClass::RawBufferVectorStore:
Changed |= scalarizeVectorStore(HlslOP, M.getDataLayout(), CI);
continue;
case DXIL::OpCodeClass::VectorReduce:
Changed |= scalarizeVectorReduce(CI);
continue;
case DXIL::OpCodeClass::Dot:
Changed |= scalarizeVectorDot(HlslOP, CI);
continue;
case DXIL::OpCodeClass::WaveMatch:
Changed |= scalarizeVectorWaveMatch(HlslOP, CI);
continue;
default:
break;
}

// Handle DXIL Ops with vector return matching the vector params
if (Func->getReturnType()->isVectorTy())
Changed |= scalarizeVectorIntrinsic(HlslOP, CI);
}
}
return Changed;
}
};

char DxilScalarizeVectorIntrinsics::ID = 0;

ModulePass *llvm::createDxilScalarizeVectorIntrinsicsPass() {
Expand Down
2 changes: 1 addition & 1 deletion lib/HLSL/HLOperationLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2606,7 +2606,7 @@ Value *TranslateDot(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
Type *EltTy = Ty->getScalarType();

// SM6.9 introduced a DXIL operation for vectorized dot product
// The operation is only advantageous for vect size>1, vec1s will be
// The operation is only advantageous for vec size>1, vec1s will be
// lowered to a single Mul.
if (hlslOP->GetModule()->GetHLModule().GetShaderModel()->IsSM69Plus() &&
EltTy->isFloatingPointTy() && Ty->getVectorNumElements() > 1) {
Expand Down