Skip to content

[DirectX] Add support for Raw Buffer Loads and Stores for scalars and vectors of doubles and i64s in SM6.2 and earlier #146627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 133 additions & 61 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,23 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::vector_reduce_add:
case Intrinsic::vector_reduce_fadd:
return true;
case Intrinsic::dx_resource_load_rawbuffer:
if (F.getParent()->getTargetTriple().getDXILVersion() > VersionTuple(1, 2))
return false;
// fallthrough to check if double or i64
LLVM_FALLTHROUGH;
case Intrinsic::dx_resource_load_typedbuffer: {
// We need to handle i64, doubles, and vectors of them.
Type *ScalarTy =
F.getReturnType()->getStructElementType(0)->getScalarType();
return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
}
case Intrinsic::dx_resource_store_rawbuffer: {
if (F.getParent()->getTargetTriple().getDXILVersion() > VersionTuple(1, 2))
return false;
Type *ScalarTy = F.getFunctionType()->getParamType(3)->getScalarType();
return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
}
case Intrinsic::dx_resource_store_typedbuffer: {
// We need to handle i64 and doubles and vectors of i64 and doubles.
Type *ScalarTy = F.getFunctionType()->getParamType(2)->getScalarType();
Expand Down Expand Up @@ -544,63 +555,81 @@ static Value *expandRadiansIntrinsic(CallInst *Orig) {
return Builder.CreateFMul(X, PiOver180);
}

static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
static bool expandBufferLoadIntrinsic(CallInst *Orig, bool IsRaw) {
IRBuilder<> Builder(Orig);

Type *BufferTy = Orig->getType()->getStructElementType(0);
Type *ScalarTy = BufferTy->getScalarType();
bool IsDouble = ScalarTy->isDoubleTy();
assert(IsDouble || ScalarTy->isIntegerTy(64) &&
"Only expand double or int64 scalars or vectors");
bool IsVector = isa<FixedVectorType>(BufferTy);

unsigned ExtractNum = 2;
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
assert(VT->getNumElements() == 2 &&
"TypedBufferLoad vector must be size 2");
ExtractNum = 4;
if (!IsRaw)
assert(VT->getNumElements() == 2 &&
"TypedBufferLoad vector must be size 2");
ExtractNum = 2 * VT->getNumElements();
}

Type *Ty = VectorType::get(Builder.getInt32Ty(), ExtractNum, false);

Type *LoadType = StructType::get(Ty, Builder.getInt1Ty());
CallInst *Load =
Builder.CreateIntrinsic(LoadType, Intrinsic::dx_resource_load_typedbuffer,
{Orig->getOperand(0), Orig->getOperand(1)});

// extract the buffer load's result
Value *Extract = Builder.CreateExtractValue(Load, {0});

SmallVector<Value *> ExtractElements;
for (unsigned I = 0; I < ExtractNum; ++I)
ExtractElements.push_back(
Builder.CreateExtractElement(Extract, Builder.getInt32(I)));

// combine into double(s) or int64(s)
SmallVector<Value *, 2> Loads;
Value *Result = PoisonValue::get(BufferTy);
for (unsigned I = 0; I < ExtractNum; I += 2) {
Value *Combined = nullptr;
if (IsDouble)
// For doubles, use dx_asdouble intrinsic
Combined =
Builder.CreateIntrinsic(Builder.getDoubleTy(), Intrinsic::dx_asdouble,
{ExtractElements[I], ExtractElements[I + 1]});
else {
// For int64, manually combine two int32s
// First, zero-extend both values to i64
Value *Lo = Builder.CreateZExt(ExtractElements[I], Builder.getInt64Ty());
Value *Hi =
Builder.CreateZExt(ExtractElements[I + 1], Builder.getInt64Ty());
// Shift the high bits left by 32 bits
Value *ShiftedHi = Builder.CreateShl(Hi, Builder.getInt64(32));
// OR the high and low bits together
Combined = Builder.CreateOr(Lo, ShiftedHi);
unsigned Base = 0;
while (ExtractNum > 0) {
unsigned LoadNum = std::min(ExtractNum, 4u);
Type *Ty = VectorType::get(Builder.getInt32Ty(), LoadNum, false);

Type *LoadType = StructType::get(Ty, Builder.getInt1Ty());
Intrinsic::ID LoadIntrinsic = Intrinsic::dx_resource_load_typedbuffer;
SmallVector<Value *, 3> Args = {Orig->getOperand(0), Orig->getOperand(1)};
if (IsRaw) {
LoadIntrinsic = Intrinsic::dx_resource_load_rawbuffer;
Value *Tmp = Builder.getInt32(4 * Base * 2);
Args.push_back(Builder.CreateAdd(Orig->getOperand(2), Tmp));
}

if (ExtractNum == 4)
Result = Builder.CreateInsertElement(Result, Combined,
Builder.getInt32(I / 2));
else
Result = Combined;
CallInst *Load = Builder.CreateIntrinsic(LoadType, LoadIntrinsic, Args);
Loads.push_back(Load);

// extract the buffer load's result
Value *Extract = Builder.CreateExtractValue(Load, {0});

SmallVector<Value *> ExtractElements;
for (unsigned I = 0; I < LoadNum; ++I)
ExtractElements.push_back(
Builder.CreateExtractElement(Extract, Builder.getInt32(I)));

// combine into double(s) or int64(s)
for (unsigned I = 0; I < LoadNum; I += 2) {
Value *Combined = nullptr;
if (IsDouble)
// For doubles, use dx_asdouble intrinsic
Combined = Builder.CreateIntrinsic(
Builder.getDoubleTy(), Intrinsic::dx_asdouble,
{ExtractElements[I], ExtractElements[I + 1]});
else {
// For int64, manually combine two int32s
// First, zero-extend both values to i64
Value *Lo =
Builder.CreateZExt(ExtractElements[I], Builder.getInt64Ty());
Value *Hi =
Builder.CreateZExt(ExtractElements[I + 1], Builder.getInt64Ty());
// Shift the high bits left by 32 bits
Value *ShiftedHi = Builder.CreateShl(Hi, Builder.getInt64(32));
// OR the high and low bits together
Combined = Builder.CreateOr(Lo, ShiftedHi);
}

if (IsVector)
Result = Builder.CreateInsertElement(Result, Combined,
Builder.getInt32((I / 2) + Base));
Comment on lines +626 to +628
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be something like ExtractNum > 2 rather than IsVector? For odd-length vectors I'd expect the last value to be scalar.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For an odd length vector of length 3, we still need our Result to be a vector, and ExtractNum changes during the loop, so ExtractNum would be 2 when we're loading the 3rd value in a vec3.

else
Result = Combined;
}

ExtractNum -= LoadNum;
Base += LoadNum / 2;
}

Value *CheckBit = nullptr;
Expand All @@ -620,8 +649,12 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
} else {
// Use of the check bit
assert(Indices[0] == 1 && "Unexpected type for typedbufferload");
if (!CheckBit)
CheckBit = Builder.CreateExtractValue(Load, {1});
if (!CheckBit) {
SmallVector<Value *, 2> CheckBits;
for (Value *L : Loads)
CheckBits.push_back(Builder.CreateExtractValue(L, {1}));
CheckBit = Builder.CreateAnd(CheckBits);
}
EVI->replaceAllUsesWith(CheckBit);
}
EVI->eraseFromParent();
Expand All @@ -630,46 +663,52 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
return true;
}

static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
static bool expandBufferStoreIntrinsic(CallInst *Orig, bool IsRaw) {
IRBuilder<> Builder(Orig);

Type *BufferTy = Orig->getFunctionType()->getParamType(2);
Type *BufferTy = Orig->getFunctionType()->getParamType(IsRaw ? 3 : 2);
Type *ScalarTy = BufferTy->getScalarType();
bool IsDouble = ScalarTy->isDoubleTy();
assert((IsDouble || ScalarTy->isIntegerTy(64)) &&
"Only expand double or int64 scalars or vectors");

// Determine if we're dealing with a vector or scalar
bool IsVector = isa<FixedVectorType>(BufferTy);
if (IsVector) {
assert(cast<FixedVectorType>(BufferTy)->getNumElements() == 2 &&
"TypedBufferStore vector must be size 2");
unsigned ExtractNum = 2;
unsigned VecLen = 0;
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
if (!IsRaw)
assert(VT->getNumElements() == 2 &&
"TypedBufferStore vector must be size 2");
VecLen = VT->getNumElements();
ExtractNum = VecLen * 2;
}

// Create the appropriate vector type for the result
Type *Int32Ty = Builder.getInt32Ty();
Type *ResultTy = VectorType::get(Int32Ty, IsVector ? 4 : 2, false);
Type *ResultTy = VectorType::get(Int32Ty, ExtractNum, false);
Value *Val = PoisonValue::get(ResultTy);

Type *SplitElementTy = Int32Ty;
if (IsVector)
SplitElementTy = VectorType::get(SplitElementTy, 2, false);
SplitElementTy = VectorType::get(SplitElementTy, VecLen, false);

Value *LowBits = nullptr;
Value *HighBits = nullptr;
// Split the 64-bit values into 32-bit components
if (IsDouble) {
auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy);
Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble,
{Orig->getOperand(2)});
{Orig->getOperand(IsRaw ? 3 : 2)});
LowBits = Builder.CreateExtractValue(Split, 0);
HighBits = Builder.CreateExtractValue(Split, 1);
} else {
// Handle int64 type(s)
Value *InputVal = Orig->getOperand(2);
Value *InputVal = Orig->getOperand(IsRaw ? 3 : 2);
Constant *ShiftAmt = Builder.getInt64(32);
if (IsVector)
ShiftAmt = ConstantVector::getSplat(ElementCount::getFixed(2), ShiftAmt);
ShiftAmt =
ConstantVector::getSplat(ElementCount::getFixed(VecLen), ShiftAmt);

// Split into low and high 32-bit parts
LowBits = Builder.CreateTrunc(InputVal, SplitElementTy);
Expand All @@ -678,17 +717,42 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
}

if (IsVector) {
Val = Builder.CreateShuffleVector(LowBits, HighBits, {0, 2, 1, 3});
SmallVector<int, 8> Mask;
for (unsigned I = 0; I < VecLen; ++I) {
Mask.push_back(I);
Mask.push_back(I + VecLen);
}
Val = Builder.CreateShuffleVector(LowBits, HighBits, Mask);
} else {
Val = Builder.CreateInsertElement(Val, LowBits, Builder.getInt32(0));
Val = Builder.CreateInsertElement(Val, HighBits, Builder.getInt32(1));
}

// Create the final intrinsic call
Builder.CreateIntrinsic(Builder.getVoidTy(),
Intrinsic::dx_resource_store_typedbuffer,
{Orig->getOperand(0), Orig->getOperand(1), Val});
unsigned Base = 0;
while (ExtractNum > 0) {
unsigned StoreNum = std::min(ExtractNum, 4u);

Intrinsic::ID StoreIntrinsic = Intrinsic::dx_resource_store_typedbuffer;
SmallVector<Value *, 4> Args = {Orig->getOperand(0), Orig->getOperand(1)};
if (IsRaw) {
StoreIntrinsic = Intrinsic::dx_resource_store_rawbuffer;
Value *Tmp = Builder.getInt32(4 * Base);
Args.push_back(Builder.CreateAdd(Orig->getOperand(2), Tmp));
}

SmallVector<int, 4> Mask;
for (unsigned I = 0; I < StoreNum; ++I) {
Mask.push_back(Base + I);
}
Value *SubVal = Builder.CreateShuffleVector(Val, Mask);

Args.push_back(SubVal);
// Create the final intrinsic call
Builder.CreateIntrinsic(Builder.getVoidTy(), StoreIntrinsic, Args);

ExtractNum -= StoreNum;
Base += StoreNum;
}
Orig->eraseFromParent();
return true;
}
Expand Down Expand Up @@ -821,12 +885,20 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::dx_radians:
Result = expandRadiansIntrinsic(Orig);
break;
case Intrinsic::dx_resource_load_rawbuffer:
if (expandBufferLoadIntrinsic(Orig, /*IsRaw*/ true))
return true;
break;
case Intrinsic::dx_resource_store_rawbuffer:
if (expandBufferStoreIntrinsic(Orig, /*IsRaw*/ true))
return true;
break;
case Intrinsic::dx_resource_load_typedbuffer:
if (expandTypedBufferLoadIntrinsic(Orig))
if (expandBufferLoadIntrinsic(Orig, /*IsRaw*/ false))
return true;
break;
case Intrinsic::dx_resource_store_typedbuffer:
if (expandTypedBufferStoreIntrinsic(Orig))
if (expandBufferStoreIntrinsic(Orig, /*IsRaw*/ false))
return true;
break;
case Intrinsic::usub_sat:
Expand Down
8 changes: 6 additions & 2 deletions llvm/test/CodeGen/DirectX/BufferStoreDouble.ll
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ define void @storef64(double %0) {
; CHECK: [[Hi:%.*]] = extractvalue { i32, i32 } [[SD]], 1
; CHECK: [[Vec1:%.*]] = insertelement <2 x i32> poison, i32 [[Lo]], i32 0
; CHECK: [[Vec2:%.*]] = insertelement <2 x i32> [[Vec1]], i32 [[Hi]], i32 1
; this shufflevector is unnecessary but generated to avoid specialization
; CHECK: [[Vec3:%.*]] = shufflevector <2 x i32> [[Vec2]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
; CHECK: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_f64_1_0_0t.v2i32(
; CHECK-SAME: target("dx.TypedBuffer", double, 1, 0, 0) [[B]], i32 0, <2 x i32> [[Vec2]])
; CHECK-SAME: target("dx.TypedBuffer", double, 1, 0, 0) [[B]], i32 0, <2 x i32> [[Vec3]])
call void @llvm.dx.resource.store.typedbuffer(
target("dx.TypedBuffer", double, 1, 0, 0) %buffer, i32 0,
double %0)
Expand All @@ -38,8 +40,10 @@ define void @storev2f64(<2 x double> %0) {
; CHECK: [[Lo:%.*]] = extractvalue { <2 x i32>, <2 x i32> } [[SD]], 0
; CHECK: [[Hi:%.*]] = extractvalue { <2 x i32>, <2 x i32> } [[SD]], 1
; CHECK: [[Vec:%.*]] = shufflevector <2 x i32> [[Lo]], <2 x i32> [[Hi]], <4 x i32> <i32 0, i32 2, i32 1, i32 3>
; this shufflevector is unnecessary but generated to avoid specialization
; CHECK: [[Vec2:%.*]] = shufflevector <4 x i32> [[Vec]], <4 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2f64_1_0_0t.v4i32(
; CHECK-SAME: target("dx.TypedBuffer", <2 x double>, 1, 0, 0) [[B]], i32 0, <4 x i32> [[Vec]])
; CHECK-SAME: target("dx.TypedBuffer", <2 x double>, 1, 0, 0) [[B]], i32 0, <4 x i32> [[Vec2]])
call void @llvm.dx.resource.store.typedbuffer(
target("dx.TypedBuffer", <2 x double>, 1, 0, 0) %buffer, i32 0,
<2 x double> %0)
Expand Down
8 changes: 6 additions & 2 deletions llvm/test/CodeGen/DirectX/BufferStoreInt64.ll
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ define void @storei64(i64 %0) {
; CHECK-NEXT: [[TMP4:%.*]] = trunc i64 [[TMP3]] to i32
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP2]], i32 0
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP4]], i32 1
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_i64_1_0_0t.v2i32(target("dx.TypedBuffer", i64, 1, 0, 0) [[BUFFER]], i32 0, <2 x i32> [[TMP6]])
; the shufflevector is unnecessary but generated to avoid too much specialization
; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <2 x i32> [[TMP6]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_i64_1_0_0t.v2i32(target("dx.TypedBuffer", i64, 1, 0, 0) [[BUFFER]], i32 0, <2 x i32> [[TMP7]])
; CHECK-NEXT: ret void
;
%buffer = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
Expand All @@ -29,7 +31,9 @@ define void @storev2i64(<2 x i64> %0) {
; CHECK-NEXT: [[TMP3:%.*]] = lshr <2 x i64> [[TMP0]], splat (i64 32)
; CHECK-NEXT: [[TMP4:%.*]] = trunc <2 x i64> [[TMP3]] to <2 x i32>
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP4]], <4 x i32> <i32 0, i32 2, i32 1, i32 3>
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t.v4i32(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) [[BUFFER]], i32 0, <4 x i32> [[TMP13]])
; the shufflevector is unnecessary but generated to avoid too much specialization
; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <4 x i32> [[TMP13]], <4 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t.v4i32(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) [[BUFFER]], i32 0, <4 x i32> [[TMP14]])
; CHECK-NEXT: ret void
;
%buffer = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
Expand Down
24 changes: 0 additions & 24 deletions llvm/test/CodeGen/DirectX/RawBufferLoad-error64.ll

This file was deleted.

Loading
Loading