Skip to content

Commit ce8268a

Browse files
authored
Custom extensions to enable lowering for non-resource based functions (microsoft#5579)
A downstream consumer requires the use of custom function lowering using a json string as well as methods. For instance, a direct memory load not based on a resource.
1 parent f49bb3c commit ce8268a

File tree

3 files changed

+192
-65
lines changed

3 files changed

+192
-65
lines changed

include/dxc/HLSL/HLOperationLowerExtension.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ namespace hlsl {
4545
Pack, // Convert the vector arguments into structs.
4646
Resource, // Convert return value to resource return and explode vectors.
4747
Dxil, // Convert call to a dxil intrinsic.
48+
Custom, // Custom lowering based on flexible json string.
4849
};
4950

5051
// Create the lowering using the given strategy and custom codegen helper.
@@ -86,5 +87,6 @@ namespace hlsl {
8687
llvm::Value *Resource(llvm::CallInst *CI);
8788
llvm::Value *Dxil(llvm::CallInst *CI);
8889
llvm::Value *CustomResource(llvm::CallInst *CI);
90+
llvm::Value *Custom(llvm::CallInst *CI);
8991
};
9092
}

lib/HLSL/HLOperationLowerExtension.cpp

Lines changed: 118 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
5151
case 'p': return Strategy::Pack;
5252
case 'm': return Strategy::Resource;
5353
case 'd': return Strategy::Dxil;
54+
case 'c': return Strategy::Custom;
5455
default: break;
5556
}
5657
return Strategy::Unknown;
@@ -63,6 +64,7 @@ llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
6364
case Strategy::Pack: return "p";
6465
case Strategy::Resource: return "m"; // m for resource method
6566
case Strategy::Dxil: return "d";
67+
case Strategy::Custom: return "c";
6668
default: break;
6769
}
6870
return "?";
@@ -91,6 +93,7 @@ llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
9193
case Strategy::Pack: return Pack(CI);
9294
case Strategy::Resource: return Resource(CI);
9395
case Strategy::Dxil: return Dxil(CI);
96+
case Strategy::Custom: return Custom(CI);
9497
default: break;
9598
}
9699
return Unknown(CI);
@@ -373,6 +376,51 @@ Value *ExtensionLowering::Replicate(CallInst *CI) {
373376
return replicate.Generate();
374377
}
375378

379+
///////////////////////////////////////////////////////////////////////////////
380+
// Helper functions
381+
static VectorType* ConvertStructTypeToVectorType(Type* structTy) {
382+
assert(structTy->isStructTy());
383+
return VectorType::get(structTy->getStructElementType(0), structTy->getStructNumElements());
384+
}
385+
386+
static Value* PackStructIntoVector(IRBuilder<>& builder, Value* strukt) {
387+
Type* vecTy = ConvertStructTypeToVectorType(strukt->getType());
388+
Value* packed = UndefValue::get(vecTy);
389+
390+
unsigned numElements = vecTy->getVectorNumElements();
391+
for (unsigned i = 0; i < numElements; ++i) {
392+
Value* element = builder.CreateExtractValue(strukt, i);
393+
packed = builder.CreateInsertElement(packed, element, i);
394+
}
395+
396+
return packed;
397+
}
398+
399+
static StructType* ConvertVectorTypeToStructType(Type* vecTy) {
400+
assert(vecTy->isVectorTy());
401+
Type* elementTy = vecTy->getVectorElementType();
402+
unsigned numElements = vecTy->getVectorNumElements();
403+
SmallVector<Type*, 4> elements;
404+
for (unsigned i = 0; i < numElements; ++i)
405+
elements.push_back(elementTy);
406+
407+
return StructType::get(vecTy->getContext(), elements);
408+
}
409+
410+
411+
static Value* PackVectorIntoStruct(IRBuilder<>& builder, Value* vec) {
412+
StructType* structTy = ConvertVectorTypeToStructType(vec->getType());
413+
Value* packed = UndefValue::get(structTy);
414+
415+
unsigned numElements = structTy->getStructNumElements();
416+
for (unsigned i = 0; i < numElements; ++i) {
417+
Value* element = builder.CreateExtractElement(vec, i);
418+
packed = builder.CreateInsertValue(packed, element, { i });
419+
}
420+
421+
return packed;
422+
}
423+
376424
///////////////////////////////////////////////////////////////////////////////
377425
// Packed Lowering.
378426
class PackCall {
@@ -389,17 +437,6 @@ class PackCall {
389437
Value *result = CreateCall(args);
390438
return UnpackResult(result);
391439
}
392-
393-
static StructType *ConvertVectorTypeToStructType(Type *vecTy) {
394-
assert(vecTy->isVectorTy());
395-
Type *elementTy = vecTy->getVectorElementType();
396-
unsigned numElements = vecTy->getVectorNumElements();
397-
SmallVector<Type *, 4> elements;
398-
for (unsigned i = 0; i < numElements; ++i)
399-
elements.push_back(elementTy);
400-
401-
return StructType::get(vecTy->getContext(), elements);
402-
}
403440

404441
private:
405442
CallInst *m_CI;
@@ -425,37 +462,6 @@ class PackCall {
425462
}
426463
return result;
427464
}
428-
429-
static VectorType *ConvertStructTypeToVectorType(Type *structTy) {
430-
assert(structTy->isStructTy());
431-
return VectorType::get(structTy->getStructElementType(0), structTy->getStructNumElements());
432-
}
433-
434-
static Value *PackVectorIntoStruct(IRBuilder<> &builder, Value *vec) {
435-
StructType *structTy = ConvertVectorTypeToStructType(vec->getType());
436-
Value *packed = UndefValue::get(structTy);
437-
438-
unsigned numElements = structTy->getStructNumElements();
439-
for (unsigned i = 0; i < numElements; ++i) {
440-
Value *element = builder.CreateExtractElement(vec, i);
441-
packed = builder.CreateInsertValue(packed, element, { i });
442-
}
443-
444-
return packed;
445-
}
446-
447-
static Value *PackStructIntoVector(IRBuilder<> &builder, Value *strukt) {
448-
Type *vecTy = ConvertStructTypeToVectorType(strukt->getType());
449-
Value *packed = UndefValue::get(vecTy);
450-
451-
unsigned numElements = vecTy->getVectorNumElements();
452-
for (unsigned i = 0; i < numElements; ++i) {
453-
Value *element = builder.CreateExtractValue(strukt, i);
454-
packed = builder.CreateInsertElement(packed, element, i);
455-
}
456-
457-
return packed;
458-
}
459465
};
460466

461467
class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
@@ -468,7 +474,7 @@ class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
468474

469475
Type *TranslateIfVector(Type *ty) {
470476
if (ty->isVectorTy())
471-
ty = PackCall::ConvertVectorTypeToStructType(ty);
477+
ty = ConvertVectorTypeToStructType(ty);
472478
return ty;
473479
}
474480
};
@@ -713,10 +719,30 @@ Value *ExtensionLowering::Resource(CallInst *CI) {
713719
// dxil: @MyTextureOp(17, handle, a.x, a.y, undef, c.x, c.y)
714720
//
715721
//
716-
class CustomResourceLowering
722+
class CustomLowering
717723
{
718724
public:
719-
CustomResourceLowering(StringRef LoweringInfo, CallInst *CI, HLResourceLookup &ResourceLookup)
725+
CustomLowering(StringRef LoweringInfo, CallInst* CI)
726+
{
727+
// Parse lowering info json format.
728+
std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap =
729+
ParseLoweringInfo(LoweringInfo, CI->getContext());
730+
731+
// Find the default lowering kind
732+
std::vector<DxilArgInfo> *pArgInfo = nullptr;
733+
if (LoweringInfoMap.count(m_DefaultInfoName))
734+
{
735+
pArgInfo = &LoweringInfoMap.at(m_DefaultInfoName);
736+
}
737+
else
738+
{
739+
ThrowExtensionError("Unable to find lowering info for custom function");
740+
}
741+
// Don't explode vectors for custom functions
742+
GenerateLoweredArgs(CI, *pArgInfo);
743+
}
744+
745+
CustomLowering(StringRef LoweringInfo, CallInst *CI, HLResourceLookup &ResourceLookup)
720746
{
721747
// Parse lowering info json format.
722748
std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap =
@@ -732,15 +758,14 @@ class CustomResourceLowering
732758
std::string Name(pName);
733759

734760
// Select lowering info to use based on resource kind.
735-
const char *DefaultInfoName = "default";
736761
std::vector<DxilArgInfo> *pArgInfo = nullptr;
737762
if (LoweringInfoMap.count(Name))
738763
{
739764
pArgInfo = &LoweringInfoMap.at(Name);
740765
}
741-
else if (LoweringInfoMap.count(DefaultInfoName))
766+
else if (LoweringInfoMap.count(m_DefaultInfoName))
742767
{
743-
pArgInfo = &LoweringInfoMap.at(DefaultInfoName);
768+
pArgInfo = &LoweringInfoMap.at(m_DefaultInfoName);
744769
}
745770
else
746771
{
@@ -775,6 +800,7 @@ class CustomResourceLowering
775800
{"?half", Type::getHalfTy(Ctx)},
776801
{"?i8", Type::getInt8Ty(Ctx)},
777802
{"?i16", Type::getInt16Ty(Ctx)},
803+
{"?i1", Type::getInt1Ty(Ctx)},
778804
};
779805
DXASSERT(m_OptionalTypes.empty(), "Init should only be called once");
780806
m_OptionalTypes.clear();
@@ -965,6 +991,13 @@ class CustomResourceLowering
965991
}
966992
}
967993
}
994+
else
995+
{
996+
// If the vector isn't exploded, use structs for DXIL Intrinsics
997+
if (Arg->getType()->isVectorTy()) {
998+
Arg = PackVectorIntoStruct(builder, Arg);
999+
}
1000+
}
9681001

9691002
m_LoweredArgs.push_back(Arg);
9701003
}
@@ -984,27 +1017,28 @@ class CustomResourceLowering
9841017

9851018
std::vector<Value *> m_LoweredArgs;
9861019
SmallVector<OptionalTypeSpec, 5> m_OptionalTypes;
1020+
const char* m_DefaultInfoName = "default";
9871021
};
9881022

9891023
// Boilerplate to reuse exising logic as much as possible.
9901024
// We just want to overload GetFunctionType here.
991-
class CustomResourceFunctionTranslator : public FunctionTranslator {
1025+
class CustomFunctionTranslator : public FunctionTranslator {
9921026
public:
9931027
static Function *GetLoweredFunction(
994-
const CustomResourceLowering &CustomLowering,
995-
ResourceFunctionTypeTranslator &typeTranslator,
1028+
const CustomLowering &CustomLowering,
1029+
FunctionTypeTranslator &typeTranslator,
9961030
CallInst *CI,
9971031
ExtensionLowering &lower
9981032
)
9991033
{
1000-
CustomResourceFunctionTranslator T(CustomLowering, typeTranslator, lower);
1034+
CustomFunctionTranslator T(CustomLowering, typeTranslator, lower);
10011035
return T.FunctionTranslator::GetLoweredFunction(CI);
10021036
}
10031037

10041038
private:
1005-
CustomResourceFunctionTranslator(
1006-
const CustomResourceLowering &CustomLowering,
1007-
ResourceFunctionTypeTranslator &typeTranslator,
1039+
CustomFunctionTranslator(
1040+
const CustomLowering &CustomLowering,
1041+
FunctionTypeTranslator &typeTranslator,
10081042
ExtensionLowering &lower
10091043
)
10101044
: FunctionTranslator(typeTranslator, lower)
@@ -1023,15 +1057,15 @@ class CustomResourceFunctionTranslator : public FunctionTranslator {
10231057
}
10241058

10251059
private:
1026-
const CustomResourceLowering &m_CustomLowering;
1060+
const CustomLowering &m_CustomLowering;
10271061
};
10281062

10291063
// Boilerplate to reuse exising logic as much as possible.
10301064
// We just want to overload Generate here.
10311065
class CustomResourceMethodCall : public ResourceMethodCall
10321066
{
10331067
public:
1034-
CustomResourceMethodCall(CallInst *CI, const CustomResourceLowering &CustomLowering)
1068+
CustomResourceMethodCall(CallInst *CI, const CustomLowering &CustomLowering)
10351069
: ResourceMethodCall(CI)
10361070
, m_CustomLowering(CustomLowering)
10371071
{}
@@ -1043,14 +1077,14 @@ class CustomResourceMethodCall : public ResourceMethodCall
10431077
}
10441078

10451079
private:
1046-
const CustomResourceLowering &m_CustomLowering;
1080+
const CustomLowering &m_CustomLowering;
10471081
};
10481082

10491083
// Support custom lowering logic for resource functions.
10501084
Value *ExtensionLowering::CustomResource(CallInst *CI) {
1051-
CustomResourceLowering CustomLowering(m_extraStrategyInfo, CI, m_hlResourceLookup);
1085+
CustomLowering CustomLowering(m_extraStrategyInfo, CI, m_hlResourceLookup);
10521086
ResourceFunctionTypeTranslator ResourceTypeTranslator(m_hlslOp);
1053-
Function *ResourceFunction = CustomResourceFunctionTranslator::GetLoweredFunction(
1087+
Function *ResourceFunction = CustomFunctionTranslator::GetLoweredFunction(
10541088
CustomLowering,
10551089
ResourceTypeTranslator,
10561090
CI,
@@ -1064,6 +1098,30 @@ Value *ExtensionLowering::CustomResource(CallInst *CI) {
10641098
return Result;
10651099
}
10661100

1101+
// Support custom lowering logic for arbitrary functions.
1102+
Value *ExtensionLowering::Custom(CallInst *CI) {
1103+
CustomLowering CustomLowering(m_extraStrategyInfo, CI);
1104+
PackedFunctionTypeTranslator TypeTranslator;
1105+
Function *CustomFunction = CustomFunctionTranslator::GetLoweredFunction(
1106+
CustomLowering,
1107+
TypeTranslator,
1108+
CI,
1109+
*this
1110+
);
1111+
if (!CustomFunction)
1112+
return NoTranslation(CI);
1113+
1114+
IRBuilder<> builder(CI);
1115+
Value* result = builder.CreateCall(CustomFunction, CustomLowering.GetLoweredArgs());
1116+
1117+
// Arbitrary functions will expect vectors, not structs
1118+
if (CustomFunction->getReturnType()->isStructTy()) {
1119+
return PackStructIntoVector(builder, result);
1120+
}
1121+
1122+
return result;
1123+
}
1124+
10671125
///////////////////////////////////////////////////////////////////////////////
10681126
// Dxil Lowering.
10691127

0 commit comments

Comments
 (0)