@@ -51,6 +51,7 @@ ExtensionLowering::Strategy ExtensionLowering::GetStrategy(StringRef strategy) {
51
51
case ' p' : return Strategy::Pack;
52
52
case ' m' : return Strategy::Resource;
53
53
case ' d' : return Strategy::Dxil;
54
+ case ' c' : return Strategy::Custom;
54
55
default : break ;
55
56
}
56
57
return Strategy::Unknown;
@@ -63,6 +64,7 @@ llvm::StringRef ExtensionLowering::GetStrategyName(Strategy strategy) {
63
64
case Strategy::Pack: return " p" ;
64
65
case Strategy::Resource: return " m" ; // m for resource method
65
66
case Strategy::Dxil: return " d" ;
67
+ case Strategy::Custom: return " c" ;
66
68
default : break ;
67
69
}
68
70
return " ?" ;
@@ -91,6 +93,7 @@ llvm::Value *ExtensionLowering::Translate(llvm::CallInst *CI) {
91
93
case Strategy::Pack: return Pack (CI);
92
94
case Strategy::Resource: return Resource (CI);
93
95
case Strategy::Dxil: return Dxil (CI);
96
+ case Strategy::Custom: return Custom (CI);
94
97
default : break ;
95
98
}
96
99
return Unknown (CI);
@@ -373,6 +376,51 @@ Value *ExtensionLowering::Replicate(CallInst *CI) {
373
376
return replicate.Generate ();
374
377
}
375
378
379
+ // /////////////////////////////////////////////////////////////////////////////
380
+ // Helper functions
381
+ static VectorType* ConvertStructTypeToVectorType (Type* structTy) {
382
+ assert (structTy->isStructTy ());
383
+ return VectorType::get (structTy->getStructElementType (0 ), structTy->getStructNumElements ());
384
+ }
385
+
386
+ static Value* PackStructIntoVector (IRBuilder<>& builder, Value* strukt) {
387
+ Type* vecTy = ConvertStructTypeToVectorType (strukt->getType ());
388
+ Value* packed = UndefValue::get (vecTy);
389
+
390
+ unsigned numElements = vecTy->getVectorNumElements ();
391
+ for (unsigned i = 0 ; i < numElements; ++i) {
392
+ Value* element = builder.CreateExtractValue (strukt, i);
393
+ packed = builder.CreateInsertElement (packed, element, i);
394
+ }
395
+
396
+ return packed;
397
+ }
398
+
399
+ static StructType* ConvertVectorTypeToStructType (Type* vecTy) {
400
+ assert (vecTy->isVectorTy ());
401
+ Type* elementTy = vecTy->getVectorElementType ();
402
+ unsigned numElements = vecTy->getVectorNumElements ();
403
+ SmallVector<Type*, 4 > elements;
404
+ for (unsigned i = 0 ; i < numElements; ++i)
405
+ elements.push_back (elementTy);
406
+
407
+ return StructType::get (vecTy->getContext (), elements);
408
+ }
409
+
410
+
411
+ static Value* PackVectorIntoStruct (IRBuilder<>& builder, Value* vec) {
412
+ StructType* structTy = ConvertVectorTypeToStructType (vec->getType ());
413
+ Value* packed = UndefValue::get (structTy);
414
+
415
+ unsigned numElements = structTy->getStructNumElements ();
416
+ for (unsigned i = 0 ; i < numElements; ++i) {
417
+ Value* element = builder.CreateExtractElement (vec, i);
418
+ packed = builder.CreateInsertValue (packed, element, { i });
419
+ }
420
+
421
+ return packed;
422
+ }
423
+
376
424
// /////////////////////////////////////////////////////////////////////////////
377
425
// Packed Lowering.
378
426
class PackCall {
@@ -389,17 +437,6 @@ class PackCall {
389
437
Value *result = CreateCall (args);
390
438
return UnpackResult (result);
391
439
}
392
-
393
- static StructType *ConvertVectorTypeToStructType (Type *vecTy) {
394
- assert (vecTy->isVectorTy ());
395
- Type *elementTy = vecTy->getVectorElementType ();
396
- unsigned numElements = vecTy->getVectorNumElements ();
397
- SmallVector<Type *, 4 > elements;
398
- for (unsigned i = 0 ; i < numElements; ++i)
399
- elements.push_back (elementTy);
400
-
401
- return StructType::get (vecTy->getContext (), elements);
402
- }
403
440
404
441
private:
405
442
CallInst *m_CI;
@@ -425,37 +462,6 @@ class PackCall {
425
462
}
426
463
return result;
427
464
}
428
-
429
- static VectorType *ConvertStructTypeToVectorType (Type *structTy) {
430
- assert (structTy->isStructTy ());
431
- return VectorType::get (structTy->getStructElementType (0 ), structTy->getStructNumElements ());
432
- }
433
-
434
- static Value *PackVectorIntoStruct (IRBuilder<> &builder, Value *vec) {
435
- StructType *structTy = ConvertVectorTypeToStructType (vec->getType ());
436
- Value *packed = UndefValue::get (structTy);
437
-
438
- unsigned numElements = structTy->getStructNumElements ();
439
- for (unsigned i = 0 ; i < numElements; ++i) {
440
- Value *element = builder.CreateExtractElement (vec, i);
441
- packed = builder.CreateInsertValue (packed, element, { i });
442
- }
443
-
444
- return packed;
445
- }
446
-
447
- static Value *PackStructIntoVector (IRBuilder<> &builder, Value *strukt) {
448
- Type *vecTy = ConvertStructTypeToVectorType (strukt->getType ());
449
- Value *packed = UndefValue::get (vecTy);
450
-
451
- unsigned numElements = vecTy->getVectorNumElements ();
452
- for (unsigned i = 0 ; i < numElements; ++i) {
453
- Value *element = builder.CreateExtractValue (strukt, i);
454
- packed = builder.CreateInsertElement (packed, element, i);
455
- }
456
-
457
- return packed;
458
- }
459
465
};
460
466
461
467
class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
@@ -468,7 +474,7 @@ class PackedFunctionTypeTranslator : public FunctionTypeTranslator {
468
474
469
475
Type *TranslateIfVector (Type *ty) {
470
476
if (ty->isVectorTy ())
471
- ty = PackCall:: ConvertVectorTypeToStructType (ty);
477
+ ty = ConvertVectorTypeToStructType (ty);
472
478
return ty;
473
479
}
474
480
};
@@ -713,10 +719,30 @@ Value *ExtensionLowering::Resource(CallInst *CI) {
713
719
// dxil: @MyTextureOp(17, handle, a.x, a.y, undef, c.x, c.y)
714
720
//
715
721
//
716
- class CustomResourceLowering
722
+ class CustomLowering
717
723
{
718
724
public:
719
- CustomResourceLowering (StringRef LoweringInfo, CallInst *CI, HLResourceLookup &ResourceLookup)
725
+ CustomLowering (StringRef LoweringInfo, CallInst* CI)
726
+ {
727
+ // Parse lowering info json format.
728
+ std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap =
729
+ ParseLoweringInfo (LoweringInfo, CI->getContext ());
730
+
731
+ // Find the default lowering kind
732
+ std::vector<DxilArgInfo> *pArgInfo = nullptr ;
733
+ if (LoweringInfoMap.count (m_DefaultInfoName))
734
+ {
735
+ pArgInfo = &LoweringInfoMap.at (m_DefaultInfoName);
736
+ }
737
+ else
738
+ {
739
+ ThrowExtensionError (" Unable to find lowering info for custom function" );
740
+ }
741
+ // Don't explode vectors for custom functions
742
+ GenerateLoweredArgs (CI, *pArgInfo);
743
+ }
744
+
745
+ CustomLowering (StringRef LoweringInfo, CallInst *CI, HLResourceLookup &ResourceLookup)
720
746
{
721
747
// Parse lowering info json format.
722
748
std::map<ResourceKindName, std::vector<DxilArgInfo>> LoweringInfoMap =
@@ -732,15 +758,14 @@ class CustomResourceLowering
732
758
std::string Name (pName);
733
759
734
760
// Select lowering info to use based on resource kind.
735
- const char *DefaultInfoName = " default" ;
736
761
std::vector<DxilArgInfo> *pArgInfo = nullptr ;
737
762
if (LoweringInfoMap.count (Name))
738
763
{
739
764
pArgInfo = &LoweringInfoMap.at (Name);
740
765
}
741
- else if (LoweringInfoMap.count (DefaultInfoName ))
766
+ else if (LoweringInfoMap.count (m_DefaultInfoName ))
742
767
{
743
- pArgInfo = &LoweringInfoMap.at (DefaultInfoName );
768
+ pArgInfo = &LoweringInfoMap.at (m_DefaultInfoName );
744
769
}
745
770
else
746
771
{
@@ -775,6 +800,7 @@ class CustomResourceLowering
775
800
{" ?half" , Type::getHalfTy (Ctx)},
776
801
{" ?i8" , Type::getInt8Ty (Ctx)},
777
802
{" ?i16" , Type::getInt16Ty (Ctx)},
803
+ {" ?i1" , Type::getInt1Ty (Ctx)},
778
804
};
779
805
DXASSERT (m_OptionalTypes.empty (), " Init should only be called once" );
780
806
m_OptionalTypes.clear ();
@@ -965,6 +991,13 @@ class CustomResourceLowering
965
991
}
966
992
}
967
993
}
994
+ else
995
+ {
996
+ // If the vector isn't exploded, use structs for DXIL Intrinsics
997
+ if (Arg->getType ()->isVectorTy ()) {
998
+ Arg = PackVectorIntoStruct (builder, Arg);
999
+ }
1000
+ }
968
1001
969
1002
m_LoweredArgs.push_back (Arg);
970
1003
}
@@ -984,27 +1017,28 @@ class CustomResourceLowering
984
1017
985
1018
std::vector<Value *> m_LoweredArgs;
986
1019
SmallVector<OptionalTypeSpec, 5 > m_OptionalTypes;
1020
+ const char * m_DefaultInfoName = " default" ;
987
1021
};
988
1022
989
1023
// Boilerplate to reuse exising logic as much as possible.
990
1024
// We just want to overload GetFunctionType here.
991
- class CustomResourceFunctionTranslator : public FunctionTranslator {
1025
+ class CustomFunctionTranslator : public FunctionTranslator {
992
1026
public:
993
1027
static Function *GetLoweredFunction (
994
- const CustomResourceLowering &CustomLowering,
995
- ResourceFunctionTypeTranslator &typeTranslator,
1028
+ const CustomLowering &CustomLowering,
1029
+ FunctionTypeTranslator &typeTranslator,
996
1030
CallInst *CI,
997
1031
ExtensionLowering &lower
998
1032
)
999
1033
{
1000
- CustomResourceFunctionTranslator T (CustomLowering, typeTranslator, lower);
1034
+ CustomFunctionTranslator T (CustomLowering, typeTranslator, lower);
1001
1035
return T.FunctionTranslator ::GetLoweredFunction (CI);
1002
1036
}
1003
1037
1004
1038
private:
1005
- CustomResourceFunctionTranslator (
1006
- const CustomResourceLowering &CustomLowering,
1007
- ResourceFunctionTypeTranslator &typeTranslator,
1039
+ CustomFunctionTranslator (
1040
+ const CustomLowering &CustomLowering,
1041
+ FunctionTypeTranslator &typeTranslator,
1008
1042
ExtensionLowering &lower
1009
1043
)
1010
1044
: FunctionTranslator(typeTranslator, lower)
@@ -1023,15 +1057,15 @@ class CustomResourceFunctionTranslator : public FunctionTranslator {
1023
1057
}
1024
1058
1025
1059
private:
1026
- const CustomResourceLowering &m_CustomLowering;
1060
+ const CustomLowering &m_CustomLowering;
1027
1061
};
1028
1062
1029
1063
// Boilerplate to reuse exising logic as much as possible.
1030
1064
// We just want to overload Generate here.
1031
1065
class CustomResourceMethodCall : public ResourceMethodCall
1032
1066
{
1033
1067
public:
1034
- CustomResourceMethodCall (CallInst *CI, const CustomResourceLowering &CustomLowering)
1068
+ CustomResourceMethodCall (CallInst *CI, const CustomLowering &CustomLowering)
1035
1069
: ResourceMethodCall(CI)
1036
1070
, m_CustomLowering(CustomLowering)
1037
1071
{}
@@ -1043,14 +1077,14 @@ class CustomResourceMethodCall : public ResourceMethodCall
1043
1077
}
1044
1078
1045
1079
private:
1046
- const CustomResourceLowering &m_CustomLowering;
1080
+ const CustomLowering &m_CustomLowering;
1047
1081
};
1048
1082
1049
1083
// Support custom lowering logic for resource functions.
1050
1084
Value *ExtensionLowering::CustomResource (CallInst *CI) {
1051
- CustomResourceLowering CustomLowering (m_extraStrategyInfo, CI, m_hlResourceLookup);
1085
+ CustomLowering CustomLowering (m_extraStrategyInfo, CI, m_hlResourceLookup);
1052
1086
ResourceFunctionTypeTranslator ResourceTypeTranslator (m_hlslOp);
1053
- Function *ResourceFunction = CustomResourceFunctionTranslator ::GetLoweredFunction (
1087
+ Function *ResourceFunction = CustomFunctionTranslator ::GetLoweredFunction (
1054
1088
CustomLowering,
1055
1089
ResourceTypeTranslator,
1056
1090
CI,
@@ -1064,6 +1098,30 @@ Value *ExtensionLowering::CustomResource(CallInst *CI) {
1064
1098
return Result;
1065
1099
}
1066
1100
1101
+ // Support custom lowering logic for arbitrary functions.
1102
+ Value *ExtensionLowering::Custom (CallInst *CI) {
1103
+ CustomLowering CustomLowering (m_extraStrategyInfo, CI);
1104
+ PackedFunctionTypeTranslator TypeTranslator;
1105
+ Function *CustomFunction = CustomFunctionTranslator::GetLoweredFunction (
1106
+ CustomLowering,
1107
+ TypeTranslator,
1108
+ CI,
1109
+ *this
1110
+ );
1111
+ if (!CustomFunction)
1112
+ return NoTranslation (CI);
1113
+
1114
+ IRBuilder<> builder (CI);
1115
+ Value* result = builder.CreateCall (CustomFunction, CustomLowering.GetLoweredArgs ());
1116
+
1117
+ // Arbitrary functions will expect vectors, not structs
1118
+ if (CustomFunction->getReturnType ()->isStructTy ()) {
1119
+ return PackStructIntoVector (builder, result);
1120
+ }
1121
+
1122
+ return result;
1123
+ }
1124
+
1067
1125
// /////////////////////////////////////////////////////////////////////////////
1068
1126
// Dxil Lowering.
1069
1127
0 commit comments