Skip to content

Commit 1eb83c7

Browse files
author
Greg Roth
authored
Allow native vectors for LLVM operations (microsoft#7155)
Disables various forms of scalarization and vector elimination to permit vectors to pass through to final DXIL when used in native LLVM operations and loading/storing. Introduces a few vector manipulation llvm instructions to DXIL allowing for them to appear in output DXIL. Skips passes for 6.9 that scalarize, convert to arrays, or otherwise eliminate vectors. This eliminates the element-by-element extraction, application, and reconstitution of the vectors to operators. In many cases, this required plumbing the shader model information to passes that didn't have it before and also the recreation of dxil version information from metadata where necessary. Many changes were needed for the MatrixBitcastLower pass related to linking to avoid converting matrix vectors, but also to perform the conversion if a shader was compiled for 6.9+, but then linked to a earlier target. This now adapts to the linker target to either preserve vectors for 6.9 or arrays for previous versions. This requires running the DynamicIndexing VectorToArray pass during linking since 6_x and 6_9+ will fail to run this in the initial compile, but will still need to lower vectors to arrays. This required making the pass particularly robust to different sources of version information as compiling, linking, and running optimization in isolation each require retrieval from a different source. The latter two sources are facilitated with a dxilutil function. Ternary conditional/select operators were element extracted in codegen. Removing this allows 6.9 to preserve the vectors, but also maintains behavior for previous shader models because the operations get scalarized later anyway. This was in the region of work to allow short circuiting, but the effect of that is to introduce the select and skip the later code that implements short circuiting for supported cases. Test confirm that no short circuiting is introduced for native vectors. Adds extensive tests for these operations using different types and sizes and testing them appropriately. Booleans produce significantly different code, so they get their own test. Vec1s have some special treatment as they are not allowed in final dxil, so they still need to be scalarized. This requires value specific conditionals in transformation passes. Testing confirms that this is done. Fixes microsoft#7123
1 parent 33bc44a commit 1eb83c7

30 files changed

+5242
-78
lines changed

include/dxc/DXIL/DxilInstructions.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,42 @@ struct LlvmInst_VAArg {
645645
bool isAllowed() const { return false; }
646646
};
647647

648+
/// This instruction extracts from vector
649+
struct LlvmInst_ExtractElement {
650+
llvm::Instruction *Instr;
651+
// Construction and identification
652+
LlvmInst_ExtractElement(llvm::Instruction *pInstr) : Instr(pInstr) {}
653+
operator bool() const {
654+
return Instr->getOpcode() == llvm::Instruction::ExtractElement;
655+
}
656+
// Validation support
657+
bool isAllowed() const { return true; }
658+
};
659+
660+
/// This instruction inserts into vector
661+
struct LlvmInst_InsertElement {
662+
llvm::Instruction *Instr;
663+
// Construction and identification
664+
LlvmInst_InsertElement(llvm::Instruction *pInstr) : Instr(pInstr) {}
665+
operator bool() const {
666+
return Instr->getOpcode() == llvm::Instruction::InsertElement;
667+
}
668+
// Validation support
669+
bool isAllowed() const { return true; }
670+
};
671+
672+
/// This instruction Shuffle two vectors
673+
struct LlvmInst_ShuffleVector {
674+
llvm::Instruction *Instr;
675+
// Construction and identification
676+
LlvmInst_ShuffleVector(llvm::Instruction *pInstr) : Instr(pInstr) {}
677+
operator bool() const {
678+
return Instr->getOpcode() == llvm::Instruction::ShuffleVector;
679+
}
680+
// Validation support
681+
bool isAllowed() const { return true; }
682+
};
683+
648684
/// This instruction extracts from aggregate
649685
struct LlvmInst_ExtractValue {
650686
llvm::Instruction *Instr;

include/dxc/DXIL/DxilMetadataHelper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,8 @@ class DxilMDHelper {
427427
// Dxil version.
428428
void EmitDxilVersion(unsigned Major, unsigned Minor);
429429
void LoadDxilVersion(unsigned &Major, unsigned &Minor);
430+
static bool LoadDxilVersion(const llvm::Module *pModule, unsigned &Major,
431+
unsigned &Minor);
430432

431433
// Validator version.
432434
void EmitValidatorVersion(unsigned Major, unsigned Minor);

include/dxc/DXIL/DxilUtil.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ bool DeleteDeadAllocas(llvm::Function &F);
223223
llvm::Value *GEPIdxToOffset(llvm::GetElementPtrInst *GEP,
224224
llvm::IRBuilder<> &Builder, hlsl::OP *OP,
225225
const llvm::DataLayout &DL);
226+
227+
// Passes back Dxil version of the given module on true return.
228+
bool LoadDxilVersion(const llvm::Module *M, unsigned &Major, unsigned &Minor);
229+
226230
} // namespace dxilutil
227231

228232
} // namespace hlsl

lib/DXIL/DxilMetadataHelper.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,17 +177,28 @@ void DxilMDHelper::EmitDxilVersion(unsigned Major, unsigned Minor) {
177177
pDxilVersionMD->addOperand(MDNode::get(m_Ctx, MDVals));
178178
}
179179

180-
void DxilMDHelper::LoadDxilVersion(unsigned &Major, unsigned &Minor) {
181-
NamedMDNode *pDxilVersionMD = m_pModule->getNamedMetadata(kDxilVersionMDName);
182-
IFTBOOL(pDxilVersionMD != nullptr, DXC_E_INCORRECT_DXIL_METADATA);
183-
IFTBOOL(pDxilVersionMD->getNumOperands() == 1, DXC_E_INCORRECT_DXIL_METADATA);
180+
// Load dxil version from metadata contained in pModule.
181+
// Returns true and passes result through
182+
// the dxil major/minor version params if valid.
183+
// Returns false if metadata is missing or invalid.
184+
bool DxilMDHelper::LoadDxilVersion(const Module *pModule, unsigned &Major,
185+
unsigned &Minor) {
186+
NamedMDNode *pDxilVersionMD = pModule->getNamedMetadata(kDxilVersionMDName);
187+
IFRBOOL(pDxilVersionMD != nullptr, false);
188+
IFRBOOL(pDxilVersionMD->getNumOperands() == 1, false);
184189

185190
MDNode *pVersionMD = pDxilVersionMD->getOperand(0);
186-
IFTBOOL(pVersionMD->getNumOperands() == kDxilVersionNumFields,
187-
DXC_E_INCORRECT_DXIL_METADATA);
191+
IFRBOOL(pVersionMD->getNumOperands() == kDxilVersionNumFields, false);
188192

189193
Major = ConstMDToUint32(pVersionMD->getOperand(kDxilVersionMajorIdx));
190194
Minor = ConstMDToUint32(pVersionMD->getOperand(kDxilVersionMinorIdx));
195+
196+
return true;
197+
}
198+
199+
void DxilMDHelper::LoadDxilVersion(unsigned &Major, unsigned &Minor) {
200+
IFTBOOL(LoadDxilVersion(m_pModule, Major, Minor),
201+
DXC_E_INCORRECT_DXIL_METADATA);
191202
}
192203

193204
//

lib/DXIL/DxilUtil.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,5 +1415,18 @@ bool DeleteDeadAllocas(llvm::Function &F) {
14151415
return Changed;
14161416
}
14171417

1418+
// Retrieve dxil version in the given module.
1419+
// Where the module doesn't already have a Dxil module,
1420+
// it identifies and returns the version info from the metatdata.
1421+
// Returns false where none of that works, but that shouldn't happen much.
1422+
bool LoadDxilVersion(const Module *M, unsigned &Major, unsigned &Minor) {
1423+
if (M->HasDxilModule()) {
1424+
M->GetDxilModule().GetShaderModel()->GetDxilVersion(Major, Minor);
1425+
return true;
1426+
}
1427+
// No module, try metadata.
1428+
return DxilMDHelper::LoadDxilVersion(M, Major, Minor);
1429+
}
1430+
14181431
} // namespace dxilutil
14191432
} // namespace hlsl

lib/DxilValidation/DxilValidation.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2193,6 +2193,9 @@ static bool ValidateType(Type *Ty, ValidationContext &ValCtx,
21932193
return true;
21942194

21952195
if (Ty->isVectorTy()) {
2196+
if (Ty->getVectorNumElements() > 1 &&
2197+
ValCtx.DxilMod.GetShaderModel()->IsSM69Plus())
2198+
return true;
21962199
ValCtx.EmitTypeError(Ty, ValidationRule::TypesNoVector);
21972200
return false;
21982201
}
@@ -2669,6 +2672,23 @@ static bool IsLLVMInstructionAllowedForLib(Instruction &I,
26692672
}
26702673
}
26712674

2675+
// Shader model specific checks for valid LLVM instructions.
2676+
// Currently only checks for pre 6.9 usage of vector operations.
2677+
// Returns false if shader model is pre 6.9 and I represents a vector
2678+
// operation. Returns true otherwise.
2679+
static bool IsLLVMInstructionAllowedForShaderModel(Instruction &I,
2680+
ValidationContext &ValCtx) {
2681+
if (ValCtx.DxilMod.GetShaderModel()->IsSM69Plus())
2682+
return true;
2683+
unsigned OpCode = I.getOpcode();
2684+
if (OpCode == Instruction::InsertElement ||
2685+
OpCode == Instruction::ExtractElement ||
2686+
OpCode == Instruction::ShuffleVector)
2687+
return false;
2688+
2689+
return true;
2690+
}
2691+
26722692
static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
26732693
bool SupportsMinPrecision =
26742694
ValCtx.DxilMod.GetGlobalFlags() & DXIL::kEnableMinPrecision;
@@ -2691,7 +2711,8 @@ static void ValidateFunctionBody(Function *F, ValidationContext &ValCtx) {
26912711
}
26922712

26932713
// Instructions must be allowed.
2694-
if (!IsLLVMInstructionAllowed(I)) {
2714+
if (!IsLLVMInstructionAllowed(I) ||
2715+
!IsLLVMInstructionAllowedForShaderModel(I, ValCtx)) {
26952716
if (!IsLLVMInstructionAllowedForLib(I, ValCtx)) {
26962717
ValCtx.EmitInstrError(&I, ValidationRule::InstrAllowed);
26972718
continue;

lib/HLSL/DxilLinker.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,12 @@ void DxilLinkJob::RunPreparePass(Module &M) {
12551255
// For static global handle.
12561256
PM.add(createLowerStaticGlobalIntoAlloca());
12571257

1258+
// Change dynamic indexing vector to array where vectors aren't
1259+
// supported, but might be there from the initial compile.
1260+
if (!pSM->IsSM69Plus())
1261+
PM.add(
1262+
createDynamicIndexingVectorToArrayPass(false /* ReplaceAllVector */));
1263+
12581264
// Remove MultiDimArray from function call arg.
12591265
PM.add(createMultiDimArrayToOneDimArrayPass());
12601266

lib/HLSL/HLMatrixBitcastLowerPass.cpp

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,17 @@ Type *TryLowerMatTy(Type *Ty) {
7676
}
7777

7878
class MatrixBitcastLowerPass : public FunctionPass {
79+
bool SupportsVectors = false;
7980

8081
public:
8182
static char ID; // Pass identification, replacement for typeid
8283
explicit MatrixBitcastLowerPass() : FunctionPass(ID) {}
8384

8485
StringRef getPassName() const override { return "Matrix Bitcast lower"; }
8586
bool runOnFunction(Function &F) override {
87+
DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
88+
SupportsVectors = DM.GetShaderModel()->IsSM69Plus();
89+
8690
bool bUpdated = false;
8791
std::unordered_set<BitCastInst *> matCastSet;
8892
for (auto blkIt = F.begin(); blkIt != F.end(); ++blkIt) {
@@ -100,7 +104,6 @@ class MatrixBitcastLowerPass : public FunctionPass {
100104
}
101105
}
102106

103-
DxilModule &DM = F.getParent()->GetOrCreateDxilModule();
104107
// Remove bitcast which has CallInst user.
105108
if (DM.GetShaderModel()->IsLib()) {
106109
for (auto it = matCastSet.begin(); it != matCastSet.end();) {
@@ -185,18 +188,19 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
185188
User *U = *(it++);
186189
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
187190
Type *EltTy = GEP->getType()->getPointerElementType();
188-
if (HLMatrixType::isa(EltTy)) {
191+
if (HLMatrixType MatTy = HLMatrixType::dyn_cast(EltTy)) {
189192
// Change gep matrixArray, 0, index
190193
// into
191194
// gep oneDimArray, 0, index * matSize
192195
IRBuilder<> Builder(GEP);
193196
SmallVector<Value *, 2> idxList(GEP->idx_begin(), GEP->idx_end());
194197
DXASSERT(idxList.size() == 2,
195198
"else not one dim matrix array index to matrix");
196-
197-
HLMatrixType MatTy = HLMatrixType::cast(EltTy);
198-
Value *matSize = Builder.getInt32(MatTy.getNumElements());
199-
idxList.back() = Builder.CreateMul(idxList.back(), matSize);
199+
unsigned NumElts = MatTy.getNumElements();
200+
if (!SupportsVectors || NumElts == 1) {
201+
Value *MatSize = Builder.getInt32(NumElts);
202+
idxList.back() = Builder.CreateMul(idxList.back(), MatSize);
203+
}
200204
Value *NewGEP = Builder.CreateGEP(A, idxList);
201205
lowerMatrix(GEP, NewGEP);
202206
DXASSERT(GEP->user_empty(), "else lower matrix fail");
@@ -211,13 +215,23 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
211215
} else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
212216
if (VectorType *Ty = dyn_cast<VectorType>(LI->getType())) {
213217
IRBuilder<> Builder(LI);
214-
Value *zeroIdx = Builder.getInt32(0);
215-
unsigned vecSize = Ty->getNumElements();
216-
Value *NewVec = UndefValue::get(LI->getType());
217-
for (unsigned i = 0; i < vecSize; i++) {
218-
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
219-
Value *Elt = Builder.CreateLoad(GEP);
220-
NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
218+
Value *NewVec = nullptr;
219+
unsigned VecSize = Ty->getVectorNumElements();
220+
if (SupportsVectors && VecSize > 1) {
221+
// Create a replacement load using the vector pointer.
222+
Instruction *NewLd = LI->clone();
223+
unsigned VecIdx = NewLd->getNumOperands() - 1;
224+
NewLd->setOperand(VecIdx, A);
225+
Builder.Insert(NewLd);
226+
NewVec = NewLd;
227+
} else {
228+
Value *zeroIdx = Builder.getInt32(0);
229+
NewVec = UndefValue::get(LI->getType());
230+
for (unsigned i = 0; i < VecSize; i++) {
231+
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
232+
Value *Elt = Builder.CreateLoad(GEP);
233+
NewVec = Builder.CreateInsertElement(NewVec, Elt, i);
234+
}
221235
}
222236
LI->replaceAllUsesWith(NewVec);
223237
LI->eraseFromParent();
@@ -228,12 +242,20 @@ void MatrixBitcastLowerPass::lowerMatrix(Instruction *M, Value *A) {
228242
Value *V = ST->getValueOperand();
229243
if (VectorType *Ty = dyn_cast<VectorType>(V->getType())) {
230244
IRBuilder<> Builder(LI);
231-
Value *zeroIdx = Builder.getInt32(0);
232-
unsigned vecSize = Ty->getNumElements();
233-
for (unsigned i = 0; i < vecSize; i++) {
234-
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
235-
Value *Elt = Builder.CreateExtractElement(V, i);
236-
Builder.CreateStore(Elt, GEP);
245+
if (SupportsVectors && Ty->getVectorNumElements() > 1) {
246+
// Create a replacement store using the vector pointer.
247+
Instruction *NewSt = ST->clone();
248+
unsigned VecIdx = NewSt->getNumOperands() - 1;
249+
NewSt->setOperand(VecIdx, A);
250+
Builder.Insert(NewSt);
251+
} else {
252+
Value *zeroIdx = Builder.getInt32(0);
253+
unsigned vecSize = Ty->getNumElements();
254+
for (unsigned i = 0; i < vecSize; i++) {
255+
Value *GEP = CreateEltGEP(A, i, zeroIdx, Builder);
256+
Value *Elt = Builder.CreateExtractElement(V, i);
257+
Builder.CreateStore(Elt, GEP);
258+
}
237259
}
238260
ST->eraseFromParent();
239261
} else {

lib/HLSL/HLModule.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,9 @@ MDTuple *HLModule::EmitHLResources() {
604604

605605
void HLModule::LoadHLResources(const llvm::MDOperand &MDO) {
606606
const llvm::MDTuple *pSRVs, *pUAVs, *pCBuffers, *pSamplers;
607+
// No resources. Nothing to do.
608+
if (MDO.get() == nullptr)
609+
return;
607610
m_pMDHelper->GetDxilResources(MDO, pSRVs, pUAVs, pCBuffers, pSamplers);
608611

609612
// Load SRV records.

lib/Transforms/Scalar/LowerTypePasses.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//===----------------------------------------------------------------------===//
99

1010
#include "dxc/DXIL/DxilConstants.h"
11+
#include "dxc/DXIL/DxilModule.h"
1112
#include "dxc/DXIL/DxilOperations.h"
1213
#include "dxc/DXIL/DxilUtil.h"
1314
#include "dxc/HLSL/HLModule.h"
@@ -180,10 +181,12 @@ bool LowerTypePass::runOnModule(Module &M) {
180181
namespace {
181182
class DynamicIndexingVectorToArray : public LowerTypePass {
182183
bool ReplaceAllVectors;
184+
bool SupportsVectors;
183185

184186
public:
185187
explicit DynamicIndexingVectorToArray(bool ReplaceAll = false)
186-
: LowerTypePass(ID), ReplaceAllVectors(ReplaceAll) {}
188+
: LowerTypePass(ID), ReplaceAllVectors(ReplaceAll),
189+
SupportsVectors(false) {}
187190
static char ID; // Pass identification, replacement for typeid
188191
void applyOptions(PassOptions O) override;
189192
void dumpConfig(raw_ostream &OS) override;
@@ -194,6 +197,7 @@ class DynamicIndexingVectorToArray : public LowerTypePass {
194197
Type *lowerType(Type *Ty) override;
195198
Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override;
196199
StringRef getGlobalPrefix() override { return ".v"; }
200+
void initialize(Module &M) override;
197201

198202
private:
199203
bool HasVectorDynamicIndexing(Value *V);
@@ -207,6 +211,18 @@ class DynamicIndexingVectorToArray : public LowerTypePass {
207211
void ReplaceAddrSpaceCast(ConstantExpr *CE, Value *A, IRBuilder<> &Builder);
208212
};
209213

214+
void DynamicIndexingVectorToArray::initialize(Module &M) {
215+
// Set vector support according to available Dxil version.
216+
// Use HLModule or metadata for version info.
217+
// Otherwise retrieve from dxil module or metadata.
218+
unsigned Major = 0, Minor = 0;
219+
if (M.HasHLModule())
220+
M.GetHLModule().GetShaderModel()->GetDxilVersion(Major, Minor);
221+
else
222+
dxilutil::LoadDxilVersion(&M, Major, Minor);
223+
SupportsVectors = (Major == 1 && Minor >= 9);
224+
}
225+
210226
void DynamicIndexingVectorToArray::applyOptions(PassOptions O) {
211227
GetPassOptionBool(O, "ReplaceAllVectors", &ReplaceAllVectors,
212228
ReplaceAllVectors);
@@ -306,9 +322,21 @@ void DynamicIndexingVectorToArray::ReplaceStaticIndexingOnVector(Value *V) {
306322
}
307323

308324
bool DynamicIndexingVectorToArray::needToLower(Value *V) {
325+
bool MustReplaceVector = ReplaceAllVectors;
309326
Type *Ty = V->getType()->getPointerElementType();
310-
if (dyn_cast<VectorType>(Ty)) {
311-
if (isa<GlobalVariable>(V) || ReplaceAllVectors) {
327+
328+
if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
329+
// Array must be replaced even without dynamic indexing to remove vector
330+
// type in dxil.
331+
MustReplaceVector = true;
332+
Ty = dxilutil::GetArrayEltTy(AT);
333+
}
334+
335+
if (isa<VectorType>(Ty)) {
336+
// Only needed for 2+ vectors where native vectors unsupported.
337+
if (SupportsVectors && Ty->getVectorNumElements() > 1)
338+
return false;
339+
if (isa<GlobalVariable>(V) || MustReplaceVector) {
312340
return true;
313341
}
314342
// Don't lower local vector which only static indexing.
@@ -319,12 +347,6 @@ bool DynamicIndexingVectorToArray::needToLower(Value *V) {
319347
ReplaceStaticIndexingOnVector(V);
320348
return false;
321349
}
322-
} else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
323-
// Array must be replaced even without dynamic indexing to remove vector
324-
// type in dxil.
325-
// TODO: optimize static array index in later pass.
326-
Type *EltTy = dxilutil::GetArrayEltTy(AT);
327-
return isa<VectorType>(EltTy);
328350
}
329351
return false;
330352
}

0 commit comments

Comments
 (0)