Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 98 additions & 42 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"

Expand Down Expand Up @@ -70,15 +71,17 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::vector_reduce_add:
case Intrinsic::vector_reduce_fadd:
return true;
case Intrinsic::dx_resource_load_typedbuffer:
// We need to handle doubles and vector of doubles.
return F.getReturnType()
->getStructElementType(0)
->getScalarType()
->isDoubleTy();
case Intrinsic::dx_resource_store_typedbuffer:
// We need to handle doubles and vector of doubles.
return F.getFunctionType()->getParamType(2)->getScalarType()->isDoubleTy();
case Intrinsic::dx_resource_load_typedbuffer: {
// We need to handle i64, doubles, and vectors of them.
Type *ScalarTy =
F.getReturnType()->getStructElementType(0)->getScalarType();
return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
}
case Intrinsic::dx_resource_store_typedbuffer: {
// We need to handle i64 and doubles and vectors of i64 and doubles.
Type *ScalarTy = F.getFunctionType()->getParamType(2)->getScalarType();
return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
}
}
return false;
}
Expand Down Expand Up @@ -541,17 +544,32 @@ static Value *expandRadiansIntrinsic(CallInst *Orig) {
return Builder.CreateFMul(X, PiOver180);
}

static Value *createCombinedi32toi64Expansion(IRBuilder<> &Builder,
Value *LoBytes,
Value *HighBytes) {
// For int64, manually combine two int32s
// First, zero-extend both values to i64
Value *Lo = Builder.CreateZExt(LoBytes, Builder.getInt64Ty());
Value *Hi = Builder.CreateZExt(HighBytes, Builder.getInt64Ty());
// Shift the high bits left by 32 bits
Value *ShiftedHi = Builder.CreateShl(Hi, Builder.getInt64(32));
// OR the high and low bits together
return Builder.CreateOr(Lo, ShiftedHi);
}

static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
IRBuilder<> Builder(Orig);

Type *BufferTy = Orig->getType()->getStructElementType(0);
assert(BufferTy->getScalarType()->isDoubleTy() &&
"Only expand double or double2");
Type *ScalarTy = BufferTy->getScalarType();
bool IsDouble = ScalarTy->isDoubleTy();
assert(IsDouble || ScalarTy->isIntegerTy(64) &&
"Only expand double or int64 scalars or vectors");

unsigned ExtractNum = 2;
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
assert(VT->getNumElements() == 2 &&
"TypedBufferLoad double vector has wrong size");
"TypedBufferLoad vector must be size 2");
ExtractNum = 4;
}

Expand All @@ -570,22 +588,34 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
ExtractElements.push_back(
Builder.CreateExtractElement(Extract, Builder.getInt32(I)));

// combine into double(s)
// combine into double(s) or int64(s)
Value *Result = PoisonValue::get(BufferTy);
for (unsigned I = 0; I < ExtractNum; I += 2) {
Value *Dbl =
Builder.CreateIntrinsic(Builder.getDoubleTy(), Intrinsic::dx_asdouble,
{ExtractElements[I], ExtractElements[I + 1]});
Value *Combined = nullptr;
if (IsDouble)
// For doubles, use dx_asdouble intrinsic
Combined =
Builder.CreateIntrinsic(Builder.getDoubleTy(), Intrinsic::dx_asdouble,
{ExtractElements[I], ExtractElements[I + 1]});
else
Combined = createCombinedi32toi64Expansion(Builder, ExtractElements[I],
ExtractElements[I + 1]);

if (ExtractNum == 4)
Result =
Builder.CreateInsertElement(Result, Dbl, Builder.getInt32(I / 2));
Result = Builder.CreateInsertElement(Result, Combined,
Builder.getInt32(I / 2));
else
Result = Dbl;
Result = Combined;
}

Value *CheckBit = nullptr;
for (User *U : make_early_inc_range(Orig->users())) {
auto *EVI = cast<ExtractValueInst>(U);
// If it's not a ExtractValueInst, we don't know how to
// handle it
auto *EVI = dyn_cast<ExtractValueInst>(U);
if (!EVI)
llvm_unreachable("Unexpected user of typedbufferload");

ArrayRef<unsigned> Indices = EVI->getIndices();
assert(Indices.size() == 1);

Expand All @@ -609,38 +639,64 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
IRBuilder<> Builder(Orig);

Type *BufferTy = Orig->getFunctionType()->getParamType(2);
assert(BufferTy->getScalarType()->isDoubleTy() &&
"Only expand double or double2");

unsigned ExtractNum = 2;
if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
assert(VT->getNumElements() == 2 &&
"TypedBufferStore double vector has wrong size");
ExtractNum = 4;
Type *ScalarTy = BufferTy->getScalarType();
bool IsDouble = ScalarTy->isDoubleTy();
assert((IsDouble || ScalarTy->isIntegerTy(64)) &&
"Only expand double or int64 scalars or vectors");

// Determine if we're dealing with a vector or scalar
bool IsVector = isa<FixedVectorType>(BufferTy);
if (IsVector) {
assert(cast<FixedVectorType>(BufferTy)->getNumElements() == 2 &&
"TypedBufferStore vector must be size 2");
}

Type *SplitElementTy = Builder.getInt32Ty();
if (ExtractNum == 4)
// Create the appropriate vector type for the result
Type *Int32Ty = Builder.getInt32Ty();
Type *ResultTy = VectorType::get(Int32Ty, IsVector ? 4 : 2, false);
Value *Val = PoisonValue::get(ResultTy);

// Handle double type(s)
Type *SplitElementTy = Int32Ty;
if (IsVector)
SplitElementTy = VectorType::get(SplitElementTy, 2, false);

// split our double(s)
auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy);
Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble,
Orig->getOperand(2));
// create our vector
Value *LowBits = Builder.CreateExtractValue(Split, 0);
Value *HighBits = Builder.CreateExtractValue(Split, 1);
Value *Val;
if (ExtractNum == 2) {
Val = PoisonValue::get(VectorType::get(SplitElementTy, 2, false));
Value *LowBits = nullptr;
Value *HighBits = nullptr;
// Split the 64-bit values into 32-bit components
if (IsDouble) {
auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy);
Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble,
{Orig->getOperand(2)});
LowBits = Builder.CreateExtractValue(Split, 0);
HighBits = Builder.CreateExtractValue(Split, 1);
} else {
// Handle int64 type(s)
Value *InputVal = Orig->getOperand(2);
Constant *ShiftAmt = Builder.getInt64(32);
if (IsVector)
ShiftAmt = ConstantVector::getSplat(ElementCount::getFixed(2), ShiftAmt);

// Split into low and high 32-bit parts
LowBits = Builder.CreateTrunc(InputVal, SplitElementTy);
Value *ShiftedVal = Builder.CreateLShr(InputVal, ShiftAmt);
HighBits = Builder.CreateTrunc(ShiftedVal, SplitElementTy);
}

if (IsVector) {
// For vector doubles, use shuffle to create the final vector
Val = Builder.CreateShuffleVector(LowBits, HighBits, {0, 2, 1, 3});
} else {
// For scalar doubles, insert the elements
Val = Builder.CreateInsertElement(Val, LowBits, Builder.getInt32(0));
Val = Builder.CreateInsertElement(Val, HighBits, Builder.getInt32(1));
} else
Val = Builder.CreateShuffleVector(LowBits, HighBits, {0, 2, 1, 3});
}

// Create the final intrinsic call
Builder.CreateIntrinsic(Builder.getVoidTy(),
Intrinsic::dx_resource_store_typedbuffer,
{Orig->getOperand(0), Orig->getOperand(1), Val});

Orig->eraseFromParent();
return true;
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/DirectX/BufferLoadDouble.ll
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,4 @@ define void @loadf64WithCheckBit() {
; CHECK-NOT: extractvalue { double, i1 }
%cb = extractvalue {double, i1} %load0, 1
ret void
}
}
48 changes: 48 additions & 0 deletions llvm/test/CodeGen/DirectX/BufferLoadInt64.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -dxil-intrinsic-expansion %s | FileCheck %s

target triple = "dxil-pc-shadermodel6.2-compute"

define void @loadi64() {
; CHECK-LABEL: define void @loadi64() {
; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
; CHECK-NEXT: [[TMP1:%.*]] = call { <2 x i32>, i1 } @llvm.dx.resource.load.typedbuffer.v2i32.tdx.TypedBuffer_i64_1_0_0t(target("dx.TypedBuffer", i64, 1, 0, 0) [[BUFFER]], i32 0)
; CHECK-NEXT: [[TMP2:%.*]] = extractvalue { <2 x i32>, i1 } [[TMP1]], 0
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x i32> [[TMP2]], i32 0
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[TMP2]], i32 1
; CHECK-NEXT: [[TMP5:%.*]] = zext i32 [[TMP3]] to i64
; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP4]] to i64
; CHECK-NEXT: [[TMP7:%.*]] = shl i64 [[TMP6]], 32
; CHECK-NEXT: [[TMP8:%.*]] = or i64 [[TMP5]], [[TMP7]]
; CHECK-NEXT: ret void
;
%buffer = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
%result = call { i64, i1 } @llvm.dx.resource.load.typedbuffer.tdx.TypedBuffer_i64_1_0_0t(target("dx.TypedBuffer", i64, 1, 0, 0) %buffer, i32 0)
ret void
}

define void @loadv2i64() {
; CHECK-LABEL: define void @loadv2i64() {
; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
; CHECK-NEXT: [[TMP1:%.*]] = call { <4 x i32>, i1 } @llvm.dx.resource.load.typedbuffer.v4i32.tdx.TypedBuffer_v2i64_1_0_0t(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) [[BUFFER]], i32 0)
; CHECK-NEXT: [[TMP2:%.*]] = extractvalue { <4 x i32>, i1 } [[TMP1]], 0
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i32> [[TMP2]], i32 0
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x i32> [[TMP2]], i32 1
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x i32> [[TMP2]], i32 2
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i32> [[TMP2]], i32 3
; CHECK-NEXT: [[TMP7:%.*]] = zext i32 [[TMP3]] to i64
; CHECK-NEXT: [[TMP8:%.*]] = zext i32 [[TMP4]] to i64
; CHECK-NEXT: [[TMP9:%.*]] = shl i64 [[TMP8]], 32
; CHECK-NEXT: [[TMP10:%.*]] = or i64 [[TMP7]], [[TMP9]]
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <2 x i64> poison, i64 [[TMP10]], i32 0
; CHECK-NEXT: [[TMP12:%.*]] = zext i32 [[TMP5]] to i64
; CHECK-NEXT: [[TMP13:%.*]] = zext i32 [[TMP6]] to i64
; CHECK-NEXT: [[TMP14:%.*]] = shl i64 [[TMP13]], 32
; CHECK-NEXT: [[TMP15:%.*]] = or i64 [[TMP12]], [[TMP14]]
; CHECK-NEXT: [[TMP16:%.*]] = insertelement <2 x i64> [[TMP11]], i64 [[TMP15]], i32 1
; CHECK-NEXT: ret void
;
%buffer = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
%result = call { <2 x i64>, i1 } @llvm.dx.resource.load.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) %buffer, i32 0)
ret void
}
38 changes: 38 additions & 0 deletions llvm/test/CodeGen/DirectX/BufferStoreInt64.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -dxil-intrinsic-expansion %s | FileCheck %s

target triple = "dxil-pc-shadermodel6.6-compute"

define void @storei64(i64 %0) {
; CHECK-LABEL: define void @storei64(
; CHECK-SAME: i64 [[TMP0:%.*]]) {
; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[TMP0]] to i32
; CHECK-NEXT: [[TMP3:%.*]] = lshr i64 [[TMP0]], 32
; CHECK-NEXT: [[TMP4:%.*]] = trunc i64 [[TMP3]] to i32
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP2]], i32 0
; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP4]], i32 1
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_i64_1_0_0t.v2i32(target("dx.TypedBuffer", i64, 1, 0, 0) [[BUFFER]], i32 0, <2 x i32> [[TMP6]])
; CHECK-NEXT: ret void
;
%buffer = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_i64_1_0_0t(target("dx.TypedBuffer", i64, 1, 0, 0) %buffer, i32 0,i64 %0)
ret void
}


define void @storev2i64(<2 x i64> %0) {
; CHECK-LABEL: define void @storev2i64(
; CHECK-SAME: <2 x i64> [[TMP0:%.*]]) {
; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
; CHECK-NEXT: [[TMP2:%.*]] = trunc <2 x i64> [[TMP0]] to <2 x i32>
; CHECK-NEXT: [[TMP3:%.*]] = lshr <2 x i64> [[TMP0]], splat (i64 32)
; CHECK-NEXT: [[TMP4:%.*]] = trunc <2 x i64> [[TMP3]] to <2 x i32>
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP4]], <4 x i32> <i32 0, i32 2, i32 1, i32 3>
; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t.v4i32(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) [[BUFFER]], i32 0, <4 x i32> [[TMP13]])
; CHECK-NEXT: ret void
;
%buffer = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null)
call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) %buffer, i32 0, <2 x i64> %0)
ret void
}