Skip to content

Commit 4722eee

Browse files
committed
convert raw buffer stores and loads of doubles and i64s
1 parent 0494f93 commit 4722eee

File tree

9 files changed

+763
-109
lines changed

9 files changed

+763
-109
lines changed

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 133 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,23 @@ static bool isIntrinsicExpansion(Function &F) {
7171
case Intrinsic::vector_reduce_add:
7272
case Intrinsic::vector_reduce_fadd:
7373
return true;
74+
case Intrinsic::dx_resource_load_rawbuffer:
75+
if (F.getParent()->getTargetTriple().getDXILVersion() > VersionTuple(1, 2))
76+
return false;
77+
// fallthrough to check if double or i64
78+
LLVM_FALLTHROUGH;
7479
case Intrinsic::dx_resource_load_typedbuffer: {
7580
// We need to handle i64, doubles, and vectors of them.
7681
Type *ScalarTy =
7782
F.getReturnType()->getStructElementType(0)->getScalarType();
7883
return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
7984
}
85+
case Intrinsic::dx_resource_store_rawbuffer: {
86+
if (F.getParent()->getTargetTriple().getDXILVersion() > VersionTuple(1, 2))
87+
return false;
88+
Type *ScalarTy = F.getFunctionType()->getParamType(3)->getScalarType();
89+
return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
90+
}
8091
case Intrinsic::dx_resource_store_typedbuffer: {
8192
// We need to handle i64 and doubles and vectors of i64 and doubles.
8293
Type *ScalarTy = F.getFunctionType()->getParamType(2)->getScalarType();
@@ -544,63 +555,81 @@ static Value *expandRadiansIntrinsic(CallInst *Orig) {
544555
return Builder.CreateFMul(X, PiOver180);
545556
}
546557

547-
static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
558+
static bool expandBufferLoadIntrinsic(CallInst *Orig, bool IsRaw) {
548559
IRBuilder<> Builder(Orig);
549560

550561
Type *BufferTy = Orig->getType()->getStructElementType(0);
551562
Type *ScalarTy = BufferTy->getScalarType();
552563
bool IsDouble = ScalarTy->isDoubleTy();
553564
assert(IsDouble || ScalarTy->isIntegerTy(64) &&
554565
"Only expand double or int64 scalars or vectors");
566+
bool IsVector = isa<FixedVectorType>(BufferTy);
555567

556568
unsigned ExtractNum = 2;
557569
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
558-
assert(VT->getNumElements() == 2 &&
559-
"TypedBufferLoad vector must be size 2");
560-
ExtractNum = 4;
570+
if (!IsRaw)
571+
assert(VT->getNumElements() == 2 &&
572+
"TypedBufferLoad vector must be size 2");
573+
ExtractNum = 2 * VT->getNumElements();
561574
}
562575

563-
Type *Ty = VectorType::get(Builder.getInt32Ty(), ExtractNum, false);
564-
565-
Type *LoadType = StructType::get(Ty, Builder.getInt1Ty());
566-
CallInst *Load =
567-
Builder.CreateIntrinsic(LoadType, Intrinsic::dx_resource_load_typedbuffer,
568-
{Orig->getOperand(0), Orig->getOperand(1)});
569-
570-
// extract the buffer load's result
571-
Value *Extract = Builder.CreateExtractValue(Load, {0});
572-
573-
SmallVector<Value *> ExtractElements;
574-
for (unsigned I = 0; I < ExtractNum; ++I)
575-
ExtractElements.push_back(
576-
Builder.CreateExtractElement(Extract, Builder.getInt32(I)));
577-
578-
// combine into double(s) or int64(s)
576+
SmallVector<Value *, 2> Loads;
579577
Value *Result = PoisonValue::get(BufferTy);
580-
for (unsigned I = 0; I < ExtractNum; I += 2) {
581-
Value *Combined = nullptr;
582-
if (IsDouble)
583-
// For doubles, use dx_asdouble intrinsic
584-
Combined =
585-
Builder.CreateIntrinsic(Builder.getDoubleTy(), Intrinsic::dx_asdouble,
586-
{ExtractElements[I], ExtractElements[I + 1]});
587-
else {
588-
// For int64, manually combine two int32s
589-
// First, zero-extend both values to i64
590-
Value *Lo = Builder.CreateZExt(ExtractElements[I], Builder.getInt64Ty());
591-
Value *Hi =
592-
Builder.CreateZExt(ExtractElements[I + 1], Builder.getInt64Ty());
593-
// Shift the high bits left by 32 bits
594-
Value *ShiftedHi = Builder.CreateShl(Hi, Builder.getInt64(32));
595-
// OR the high and low bits together
596-
Combined = Builder.CreateOr(Lo, ShiftedHi);
578+
unsigned Base = 0;
579+
while (ExtractNum > 0) {
580+
unsigned LoadNum = std::min(ExtractNum, 4u);
581+
Type *Ty = VectorType::get(Builder.getInt32Ty(), LoadNum, false);
582+
583+
Type *LoadType = StructType::get(Ty, Builder.getInt1Ty());
584+
Intrinsic::ID LoadIntrinsic = Intrinsic::dx_resource_load_typedbuffer;
585+
SmallVector<Value *, 3> Args = {Orig->getOperand(0), Orig->getOperand(1)};
586+
if (IsRaw) {
587+
LoadIntrinsic = Intrinsic::dx_resource_load_rawbuffer;
588+
Value *Tmp = Builder.getInt32(4 * Base * 2);
589+
Args.push_back(Builder.CreateAdd(Orig->getOperand(2), Tmp));
597590
}
598591

599-
if (ExtractNum == 4)
600-
Result = Builder.CreateInsertElement(Result, Combined,
601-
Builder.getInt32(I / 2));
602-
else
603-
Result = Combined;
592+
CallInst *Load = Builder.CreateIntrinsic(LoadType, LoadIntrinsic, Args);
593+
Loads.push_back(Load);
594+
595+
// extract the buffer load's result
596+
Value *Extract = Builder.CreateExtractValue(Load, {0});
597+
598+
SmallVector<Value *> ExtractElements;
599+
for (unsigned I = 0; I < LoadNum; ++I)
600+
ExtractElements.push_back(
601+
Builder.CreateExtractElement(Extract, Builder.getInt32(I)));
602+
603+
// combine into double(s) or int64(s)
604+
for (unsigned I = 0; I < LoadNum; I += 2) {
605+
Value *Combined = nullptr;
606+
if (IsDouble)
607+
// For doubles, use dx_asdouble intrinsic
608+
Combined = Builder.CreateIntrinsic(
609+
Builder.getDoubleTy(), Intrinsic::dx_asdouble,
610+
{ExtractElements[I], ExtractElements[I + 1]});
611+
else {
612+
// For int64, manually combine two int32s
613+
// First, zero-extend both values to i64
614+
Value *Lo =
615+
Builder.CreateZExt(ExtractElements[I], Builder.getInt64Ty());
616+
Value *Hi =
617+
Builder.CreateZExt(ExtractElements[I + 1], Builder.getInt64Ty());
618+
// Shift the high bits left by 32 bits
619+
Value *ShiftedHi = Builder.CreateShl(Hi, Builder.getInt64(32));
620+
// OR the high and low bits together
621+
Combined = Builder.CreateOr(Lo, ShiftedHi);
622+
}
623+
624+
if (IsVector)
625+
Result = Builder.CreateInsertElement(Result, Combined,
626+
Builder.getInt32((I / 2) + Base));
627+
else
628+
Result = Combined;
629+
}
630+
631+
ExtractNum -= LoadNum;
632+
Base += LoadNum / 2;
604633
}
605634

606635
Value *CheckBit = nullptr;
@@ -620,8 +649,12 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
620649
} else {
621650
// Use of the check bit
622651
assert(Indices[0] == 1 && "Unexpected type for typedbufferload");
623-
if (!CheckBit)
624-
CheckBit = Builder.CreateExtractValue(Load, {1});
652+
if (!CheckBit) {
653+
SmallVector<Value *, 2> CheckBits;
654+
for (Value *L : Loads)
655+
CheckBits.push_back(Builder.CreateExtractValue(L, {1}));
656+
CheckBit = Builder.CreateAnd(CheckBits);
657+
}
625658
EVI->replaceAllUsesWith(CheckBit);
626659
}
627660
EVI->eraseFromParent();
@@ -630,46 +663,52 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
630663
return true;
631664
}
632665

633-
static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
666+
static bool expandBufferStoreIntrinsic(CallInst *Orig, bool IsRaw) {
634667
IRBuilder<> Builder(Orig);
635668

636-
Type *BufferTy = Orig->getFunctionType()->getParamType(2);
669+
Type *BufferTy = Orig->getFunctionType()->getParamType(IsRaw ? 3 : 2);
637670
Type *ScalarTy = BufferTy->getScalarType();
638671
bool IsDouble = ScalarTy->isDoubleTy();
639672
assert((IsDouble || ScalarTy->isIntegerTy(64)) &&
640673
"Only expand double or int64 scalars or vectors");
641674

642675
// Determine if we're dealing with a vector or scalar
643676
bool IsVector = isa<FixedVectorType>(BufferTy);
644-
if (IsVector) {
645-
assert(cast<FixedVectorType>(BufferTy)->getNumElements() == 2 &&
646-
"TypedBufferStore vector must be size 2");
677+
unsigned ExtractNum = 2;
678+
unsigned VecLen = 0;
679+
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
680+
if (!IsRaw)
681+
assert(VT->getNumElements() == 2 &&
682+
"TypedBufferStore vector must be size 2");
683+
VecLen = VT->getNumElements();
684+
ExtractNum = VecLen * 2;
647685
}
648686

649687
// Create the appropriate vector type for the result
650688
Type *Int32Ty = Builder.getInt32Ty();
651-
Type *ResultTy = VectorType::get(Int32Ty, IsVector ? 4 : 2, false);
689+
Type *ResultTy = VectorType::get(Int32Ty, ExtractNum, false);
652690
Value *Val = PoisonValue::get(ResultTy);
653691

654692
Type *SplitElementTy = Int32Ty;
655693
if (IsVector)
656-
SplitElementTy = VectorType::get(SplitElementTy, 2, false);
694+
SplitElementTy = VectorType::get(SplitElementTy, VecLen, false);
657695

658696
Value *LowBits = nullptr;
659697
Value *HighBits = nullptr;
660698
// Split the 64-bit values into 32-bit components
661699
if (IsDouble) {
662700
auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy);
663701
Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble,
664-
{Orig->getOperand(2)});
702+
{Orig->getOperand(IsRaw ? 3 : 2)});
665703
LowBits = Builder.CreateExtractValue(Split, 0);
666704
HighBits = Builder.CreateExtractValue(Split, 1);
667705
} else {
668706
// Handle int64 type(s)
669-
Value *InputVal = Orig->getOperand(2);
707+
Value *InputVal = Orig->getOperand(IsRaw ? 3 : 2);
670708
Constant *ShiftAmt = Builder.getInt64(32);
671709
if (IsVector)
672-
ShiftAmt = ConstantVector::getSplat(ElementCount::getFixed(2), ShiftAmt);
710+
ShiftAmt =
711+
ConstantVector::getSplat(ElementCount::getFixed(VecLen), ShiftAmt);
673712

674713
// Split into low and high 32-bit parts
675714
LowBits = Builder.CreateTrunc(InputVal, SplitElementTy);
@@ -678,17 +717,42 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
678717
}
679718

680719
if (IsVector) {
681-
Val = Builder.CreateShuffleVector(LowBits, HighBits, {0, 2, 1, 3});
720+
SmallVector<int, 8> Mask;
721+
for (unsigned I = 0; I < VecLen; ++I) {
722+
Mask.push_back(I);
723+
Mask.push_back(I + VecLen);
724+
}
725+
Val = Builder.CreateShuffleVector(LowBits, HighBits, Mask);
682726
} else {
683727
Val = Builder.CreateInsertElement(Val, LowBits, Builder.getInt32(0));
684728
Val = Builder.CreateInsertElement(Val, HighBits, Builder.getInt32(1));
685729
}
686730

687-
// Create the final intrinsic call
688-
Builder.CreateIntrinsic(Builder.getVoidTy(),
689-
Intrinsic::dx_resource_store_typedbuffer,
690-
{Orig->getOperand(0), Orig->getOperand(1), Val});
731+
unsigned Base = 0;
732+
while (ExtractNum > 0) {
733+
unsigned StoreNum = std::min(ExtractNum, 4u);
734+
735+
Intrinsic::ID StoreIntrinsic = Intrinsic::dx_resource_store_typedbuffer;
736+
SmallVector<Value *, 4> Args = {Orig->getOperand(0), Orig->getOperand(1)};
737+
if (IsRaw) {
738+
StoreIntrinsic = Intrinsic::dx_resource_store_rawbuffer;
739+
Value *Tmp = Builder.getInt32(4 * Base);
740+
Args.push_back(Builder.CreateAdd(Orig->getOperand(2), Tmp));
741+
}
742+
743+
SmallVector<int, 4> Mask;
744+
for (unsigned I = 0; I < StoreNum; ++I) {
745+
Mask.push_back(Base + I);
746+
}
747+
Value *SubVal = Builder.CreateShuffleVector(Val, Mask);
748+
749+
Args.push_back(SubVal);
750+
// Create the final intrinsic call
751+
Builder.CreateIntrinsic(Builder.getVoidTy(), StoreIntrinsic, Args);
691752

753+
ExtractNum -= StoreNum;
754+
Base += StoreNum;
755+
}
692756
Orig->eraseFromParent();
693757
return true;
694758
}
@@ -821,12 +885,20 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
821885
case Intrinsic::dx_radians:
822886
Result = expandRadiansIntrinsic(Orig);
823887
break;
888+
case Intrinsic::dx_resource_load_rawbuffer:
889+
if (expandBufferLoadIntrinsic(Orig, /*IsRaw*/ true))
890+
return true;
891+
break;
892+
case Intrinsic::dx_resource_store_rawbuffer:
893+
if (expandBufferStoreIntrinsic(Orig, /*IsRaw*/ true))
894+
return true;
895+
break;
824896
case Intrinsic::dx_resource_load_typedbuffer:
825-
if (expandTypedBufferLoadIntrinsic(Orig))
897+
if (expandBufferLoadIntrinsic(Orig, /*IsRaw*/ false))
826898
return true;
827899
break;
828900
case Intrinsic::dx_resource_store_typedbuffer:
829-
if (expandTypedBufferStoreIntrinsic(Orig))
901+
if (expandBufferStoreIntrinsic(Orig, /*IsRaw*/ false))
830902
return true;
831903
break;
832904
case Intrinsic::usub_sat:

llvm/test/CodeGen/DirectX/BufferStoreDouble.ll

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ define void @storef64(double %0) {
1616
; CHECK: [[Hi:%.*]] = extractvalue { i32, i32 } [[SD]], 1
1717
; CHECK: [[Vec1:%.*]] = insertelement <2 x i32> poison, i32 [[Lo]], i32 0
1818
; CHECK: [[Vec2:%.*]] = insertelement <2 x i32> [[Vec1]], i32 [[Hi]], i32 1
19+
; This shufflevector is redundant here, but the expansion always emits it to avoid special-casing single stores.
20+
; CHECK: [[Vec3:%.*]] = shufflevector <2 x i32> [[Vec2]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
1921
; CHECK: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_f64_1_0_0t.v2i32(
20-
; CHECK-SAME: target("dx.TypedBuffer", double, 1, 0, 0) [[B]], i32 0, <2 x i32> [[Vec2]])
22+
; CHECK-SAME: target("dx.TypedBuffer", double, 1, 0, 0) [[B]], i32 0, <2 x i32> [[Vec3]])
2123
call void @llvm.dx.resource.store.typedbuffer(
2224
target("dx.TypedBuffer", double, 1, 0, 0) %buffer, i32 0,
2325
double %0)
@@ -38,8 +40,10 @@ define void @storev2f64(<2 x double> %0) {
3840
; CHECK: [[Lo:%.*]] = extractvalue { <2 x i32>, <2 x i32> } [[SD]], 0
3941
; CHECK: [[Hi:%.*]] = extractvalue { <2 x i32>, <2 x i32> } [[SD]], 1
4042
; CHECK: [[Vec:%.*]] = shufflevector <2 x i32> [[Lo]], <2 x i32> [[Hi]], <4 x i32> <i32 0, i32 2, i32 1, i32 3>
43+
; This shufflevector is redundant here, but the expansion always emits it to avoid special-casing single stores.
44+
; CHECK: [[Vec2:%.*]] = shufflevector <4 x i32> [[Vec]], <4 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
4145
; CHECK: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2f64_1_0_0t.v4i32(
42-
; CHECK-SAME: target("dx.TypedBuffer", <2 x double>, 1, 0, 0) [[B]], i32 0, <4 x i32> [[Vec]])
46+
; CHECK-SAME: target("dx.TypedBuffer", <2 x double>, 1, 0, 0) [[B]], i32 0, <4 x i32> [[Vec2]])
4347
call void @llvm.dx.resource.store.typedbuffer(
4448
target("dx.TypedBuffer", <2 x double>, 1, 0, 0) %buffer, i32 0,
4549
<2 x double> %0)

llvm/test/CodeGen/DirectX/BufferStoreInt64.ll

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ define void @storei64(i64 %0) {
1212
; CHECK-NEXT: [[TMP4:%.*]] = trunc i64 [[TMP3]] to i32
1313
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP2]], i32 0
1414
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP4]], i32 1
15-
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_i64_1_0_0t.v2i32(target("dx.TypedBuffer", i64, 1, 0, 0) [[BUFFER]], i32 0, <2 x i32> [[TMP6]])
15+
; The shufflevector is unnecessary here, but it is generated to avoid too much specialization in the expansion.
16+
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <2 x i32> [[TMP6]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
17+
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_i64_1_0_0t.v2i32(target("dx.TypedBuffer", i64, 1, 0, 0) [[BUFFER]], i32 0, <2 x i32> [[TMP7]])
1618
; CHECK-NEXT: ret void
1719
;
1820
%buffer = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
@@ -29,7 +31,9 @@ define void @storev2i64(<2 x i64> %0) {
2931
; CHECK-NEXT: [[TMP3:%.*]] = lshr <2 x i64> [[TMP0]], splat (i64 32)
3032
; CHECK-NEXT: [[TMP4:%.*]] = trunc <2 x i64> [[TMP3]] to <2 x i32>
3133
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP4]], <4 x i32> <i32 0, i32 2, i32 1, i32 3>
32-
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t.v4i32(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) [[BUFFER]], i32 0, <4 x i32> [[TMP13]])
34+
; The shufflevector is unnecessary here, but it is generated to avoid too much specialization in the expansion.
35+
; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i32> [[TMP13]], <4 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
36+
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t.v4i32(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) [[BUFFER]], i32 0, <4 x i32> [[TMP14]])
3337
; CHECK-NEXT: ret void
3438
;
3539
%buffer = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)

llvm/test/CodeGen/DirectX/RawBufferLoad-error64.ll

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)