diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index f99e8e7ccdc5d..435b80ecaec64 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -25,6 +25,7 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" @@ -70,15 +71,17 @@ static bool isIntrinsicExpansion(Function &F) { case Intrinsic::vector_reduce_add: case Intrinsic::vector_reduce_fadd: return true; - case Intrinsic::dx_resource_load_typedbuffer: - // We need to handle doubles and vector of doubles. - return F.getReturnType() - ->getStructElementType(0) - ->getScalarType() - ->isDoubleTy(); - case Intrinsic::dx_resource_store_typedbuffer: - // We need to handle doubles and vector of doubles. - return F.getFunctionType()->getParamType(2)->getScalarType()->isDoubleTy(); + case Intrinsic::dx_resource_load_typedbuffer: { + // We need to handle i64, doubles, and vectors of them. + Type *ScalarTy = + F.getReturnType()->getStructElementType(0)->getScalarType(); + return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64); + } + case Intrinsic::dx_resource_store_typedbuffer: { + // We need to handle i64 and doubles and vectors of i64 and doubles. + Type *ScalarTy = F.getFunctionType()->getParamType(2)->getScalarType(); + return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64); + } } return false; } @@ -545,13 +548,15 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) { IRBuilder<> Builder(Orig); Type *BufferTy = Orig->getType()->getStructElementType(0); - assert(BufferTy->getScalarType()->isDoubleTy() && - "Only expand double or double2"); + Type *ScalarTy = BufferTy->getScalarType(); + bool IsDouble = ScalarTy->isDoubleTy(); + assert(IsDouble || ScalarTy->isIntegerTy(64) && + "Only expand double or int64 scalars or vectors"); unsigned ExtractNum = 2; if (auto *VT = dyn_cast(BufferTy)) { assert(VT->getNumElements() == 2 && - "TypedBufferLoad double vector has wrong size"); + "TypedBufferLoad vector must be size 2"); ExtractNum = 4; } @@ -570,22 +575,42 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) { ExtractElements.push_back( Builder.CreateExtractElement(Extract, Builder.getInt32(I))); - // combine into double(s) + // combine into double(s) or int64(s) Value *Result = PoisonValue::get(BufferTy); for (unsigned I = 0; I < ExtractNum; I += 2) { - Value *Dbl = - Builder.CreateIntrinsic(Builder.getDoubleTy(), Intrinsic::dx_asdouble, - {ExtractElements[I], ExtractElements[I + 1]}); + Value *Combined = nullptr; + if (IsDouble) + // For doubles, use dx_asdouble intrinsic + Combined = + Builder.CreateIntrinsic(Builder.getDoubleTy(), Intrinsic::dx_asdouble, + {ExtractElements[I], ExtractElements[I + 1]}); + else { + // For int64, manually combine two int32s + // First, zero-extend both values to i64 + Value *Lo = Builder.CreateZExt(ExtractElements[I], Builder.getInt64Ty()); + Value *Hi = + Builder.CreateZExt(ExtractElements[I + 1], Builder.getInt64Ty()); + // Shift the high bits left by 32 bits + Value *ShiftedHi = Builder.CreateShl(Hi, Builder.getInt64(32)); + // OR the high and low bits together + Combined = Builder.CreateOr(Lo, ShiftedHi); + } + if (ExtractNum == 4) - Result = - Builder.CreateInsertElement(Result, Dbl, Builder.getInt32(I / 2)); + Result = Builder.CreateInsertElement(Result, Combined, + Builder.getInt32(I / 2)); else - Result = Dbl; + Result = Combined; } Value *CheckBit = nullptr; for (User *U : make_early_inc_range(Orig->users())) { - auto *EVI = cast(U); + // If it's not a ExtractValueInst, we don't know how to + // handle it + auto *EVI = dyn_cast(U); + if (!EVI) + llvm_unreachable("Unexpected user of typedbufferload"); + ArrayRef Indices = EVI->getIndices(); assert(Indices.size() == 1); @@ -609,38 +634,61 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) { IRBuilder<> Builder(Orig); Type *BufferTy = Orig->getFunctionType()->getParamType(2); - assert(BufferTy->getScalarType()->isDoubleTy() && - "Only expand double or double2"); - - unsigned ExtractNum = 2; - if (auto *VT = dyn_cast(BufferTy)) { - assert(VT->getNumElements() == 2 && - "TypedBufferStore double vector has wrong size"); - ExtractNum = 4; + Type *ScalarTy = BufferTy->getScalarType(); + bool IsDouble = ScalarTy->isDoubleTy(); + assert((IsDouble || ScalarTy->isIntegerTy(64)) && + "Only expand double or int64 scalars or vectors"); + + // Determine if we're dealing with a vector or scalar + bool IsVector = isa(BufferTy); + if (IsVector) { + assert(cast(BufferTy)->getNumElements() == 2 && + "TypedBufferStore vector must be size 2"); } - Type *SplitElementTy = Builder.getInt32Ty(); - if (ExtractNum == 4) + // Create the appropriate vector type for the result + Type *Int32Ty = Builder.getInt32Ty(); + Type *ResultTy = VectorType::get(Int32Ty, IsVector ? 4 : 2, false); + Value *Val = PoisonValue::get(ResultTy); + + Type *SplitElementTy = Int32Ty; + if (IsVector) SplitElementTy = VectorType::get(SplitElementTy, 2, false); - // split our double(s) - auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy); - Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble, - Orig->getOperand(2)); - // create our vector - Value *LowBits = Builder.CreateExtractValue(Split, 0); - Value *HighBits = Builder.CreateExtractValue(Split, 1); - Value *Val; - if (ExtractNum == 2) { - Val = PoisonValue::get(VectorType::get(SplitElementTy, 2, false)); + Value *LowBits = nullptr; + Value *HighBits = nullptr; + // Split the 64-bit values into 32-bit components + if (IsDouble) { + auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy); + Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble, + {Orig->getOperand(2)}); + LowBits = Builder.CreateExtractValue(Split, 0); + HighBits = Builder.CreateExtractValue(Split, 1); + } else { + // Handle int64 type(s) + Value *InputVal = Orig->getOperand(2); + Constant *ShiftAmt = Builder.getInt64(32); + if (IsVector) + ShiftAmt = ConstantVector::getSplat(ElementCount::getFixed(2), ShiftAmt); + + // Split into low and high 32-bit parts + LowBits = Builder.CreateTrunc(InputVal, SplitElementTy); + Value *ShiftedVal = Builder.CreateLShr(InputVal, ShiftAmt); + HighBits = Builder.CreateTrunc(ShiftedVal, SplitElementTy); + } + + if (IsVector) { + Val = Builder.CreateShuffleVector(LowBits, HighBits, {0, 2, 1, 3}); + } else { Val = Builder.CreateInsertElement(Val, LowBits, Builder.getInt32(0)); Val = Builder.CreateInsertElement(Val, HighBits, Builder.getInt32(1)); - } else - Val = Builder.CreateShuffleVector(LowBits, HighBits, {0, 2, 1, 3}); + } + // Create the final intrinsic call Builder.CreateIntrinsic(Builder.getVoidTy(), Intrinsic::dx_resource_store_typedbuffer, {Orig->getOperand(0), Orig->getOperand(1), Val}); + Orig->eraseFromParent(); return true; } diff --git a/llvm/test/CodeGen/DirectX/BufferLoadDouble.ll b/llvm/test/CodeGen/DirectX/BufferLoadDouble.ll index 80a071a66364b..25abf2111060c 100644 --- a/llvm/test/CodeGen/DirectX/BufferLoadDouble.ll +++ b/llvm/test/CodeGen/DirectX/BufferLoadDouble.ll @@ -88,4 +88,4 @@ define void @loadf64WithCheckBit() { ; CHECK-NOT: extractvalue { double, i1 } %cb = extractvalue {double, i1} %load0, 1 ret void -} \ No newline at end of file +} diff --git a/llvm/test/CodeGen/DirectX/BufferLoadInt64.ll b/llvm/test/CodeGen/DirectX/BufferLoadInt64.ll new file mode 100644 index 0000000000000..42c0012ff3475 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/BufferLoadInt64.ll @@ -0,0 +1,48 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -S -dxil-intrinsic-expansion %s | FileCheck %s + +target triple = "dxil-pc-shadermodel6.2-compute" + +define void @loadi64() { +; CHECK-LABEL: define void @loadi64() { +; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null) +; CHECK-NEXT: [[TMP1:%.*]] = call { <2 x i32>, i1 } @llvm.dx.resource.load.typedbuffer.v2i32.tdx.TypedBuffer_i64_1_0_0t(target("dx.TypedBuffer", i64, 1, 0, 0) [[BUFFER]], i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = extractvalue { <2 x i32>, i1 } [[TMP1]], 0 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x i32> [[TMP2]], i32 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x i32> [[TMP2]], i32 1 +; CHECK-NEXT: [[TMP5:%.*]] = zext i32 [[TMP3]] to i64 +; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP4]] to i64 +; CHECK-NEXT: [[TMP7:%.*]] = shl i64 [[TMP6]], 32 +; CHECK-NEXT: [[TMP8:%.*]] = or i64 [[TMP5]], [[TMP7]] +; CHECK-NEXT: ret void +; + %buffer = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null) + %result = call { i64, i1 } @llvm.dx.resource.load.typedbuffer.tdx.TypedBuffer_i64_1_0_0t(target("dx.TypedBuffer", i64, 1, 0, 0) %buffer, i32 0) + ret void +} + +define void @loadv2i64() { +; CHECK-LABEL: define void @loadv2i64() { +; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null) +; CHECK-NEXT: [[TMP1:%.*]] = call { <4 x i32>, i1 } @llvm.dx.resource.load.typedbuffer.v4i32.tdx.TypedBuffer_v2i64_1_0_0t(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) [[BUFFER]], i32 0) +; CHECK-NEXT: [[TMP2:%.*]] = extractvalue { <4 x i32>, i1 } [[TMP1]], 0 +; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i32> [[TMP2]], i32 0 +; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x i32> [[TMP2]], i32 1 +; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x i32> [[TMP2]], i32 2 +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i32> [[TMP2]], i32 3 +; CHECK-NEXT: [[TMP7:%.*]] = zext i32 [[TMP3]] to i64 +; CHECK-NEXT: [[TMP8:%.*]] = zext i32 [[TMP4]] to i64 +; CHECK-NEXT: [[TMP9:%.*]] = shl i64 [[TMP8]], 32 +; CHECK-NEXT: [[TMP10:%.*]] = or i64 [[TMP7]], [[TMP9]] +; CHECK-NEXT: [[TMP11:%.*]] = insertelement <2 x i64> poison, i64 [[TMP10]], i32 0 +; CHECK-NEXT: [[TMP12:%.*]] = zext i32 [[TMP5]] to i64 +; CHECK-NEXT: [[TMP13:%.*]] = zext i32 [[TMP6]] to i64 +; CHECK-NEXT: [[TMP14:%.*]] = shl i64 [[TMP13]], 32 +; CHECK-NEXT: [[TMP15:%.*]] = or i64 [[TMP12]], [[TMP14]] +; CHECK-NEXT: [[TMP16:%.*]] = insertelement <2 x i64> [[TMP11]], i64 [[TMP15]], i32 1 +; CHECK-NEXT: ret void +; + %buffer = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null) + %result = call { <2 x i64>, i1 } @llvm.dx.resource.load.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) %buffer, i32 0) + ret void +} diff --git a/llvm/test/CodeGen/DirectX/BufferStoreInt64.ll b/llvm/test/CodeGen/DirectX/BufferStoreInt64.ll new file mode 100644 index 0000000000000..c97a02d1873a0 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/BufferStoreInt64.ll @@ -0,0 +1,38 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -S -dxil-intrinsic-expansion %s | FileCheck %s + +target triple = "dxil-pc-shadermodel6.6-compute" + +define void @storei64(i64 %0) { +; CHECK-LABEL: define void @storei64( +; CHECK-SAME: i64 [[TMP0:%.*]]) { +; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null) +; CHECK-NEXT: [[TMP2:%.*]] = trunc i64 [[TMP0]] to i32 +; CHECK-NEXT: [[TMP3:%.*]] = lshr i64 [[TMP0]], 32 +; CHECK-NEXT: [[TMP4:%.*]] = trunc i64 [[TMP3]] to i32 +; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i32> poison, i32 [[TMP2]], i32 0 +; CHECK-NEXT: [[TMP6:%.*]] = insertelement <2 x i32> [[TMP5]], i32 [[TMP4]], i32 1 +; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_i64_1_0_0t.v2i32(target("dx.TypedBuffer", i64, 1, 0, 0) [[BUFFER]], i32 0, <2 x i32> [[TMP6]]) +; CHECK-NEXT: ret void +; + %buffer = tail call target("dx.TypedBuffer", i64, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null) + call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_i64_1_0_0t(target("dx.TypedBuffer", i64, 1, 0, 0) %buffer, i32 0,i64 %0) + ret void +} + + +define void @storev2i64(<2 x i64> %0) { +; CHECK-LABEL: define void @storev2i64( +; CHECK-SAME: <2 x i64> [[TMP0:%.*]]) { +; CHECK-NEXT: [[BUFFER:%.*]] = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null) +; CHECK-NEXT: [[TMP2:%.*]] = trunc <2 x i64> [[TMP0]] to <2 x i32> +; CHECK-NEXT: [[TMP3:%.*]] = lshr <2 x i64> [[TMP0]], splat (i64 32) +; CHECK-NEXT: [[TMP4:%.*]] = trunc <2 x i64> [[TMP3]] to <2 x i32> +; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP2]], <2 x i32> [[TMP4]], <4 x i32> +; CHECK-NEXT: call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t.v4i32(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) [[BUFFER]], i32 0, <4 x i32> [[TMP13]]) +; CHECK-NEXT: ret void +; + %buffer = tail call target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2i64_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr null) + call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v2i64_1_0_0t(target("dx.TypedBuffer", <2 x i64>, 1, 0, 0) %buffer, i32 0, <2 x i64> %0) + ret void +}