Skip to content

Commit 91bcdc4

Browse files
adding dxil codegen
1 parent 7593c79 commit 91bcdc4

File tree

7 files changed

+57
-105
lines changed

7 files changed

+57
-105
lines changed

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18960,67 +18960,41 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1896018960
E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
1896118961
E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
1896218962
"asuint operands types mismatch");
18963-
1896418963
Value *Op0 = EmitScalarExpr(E->getArg(0));
1896518964
const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
1896618965
const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
1896718966

18968-
auto emitSplitDouble =
18969-
[](CGBuilderTy *Builder, llvm::Intrinsic::ID intrId, llvm::Value *arg,
18970-
llvm::Type *retType) -> std::pair<Value *, Value *> {
18971-
CallInst *CI =
18972-
Builder->CreateIntrinsic(retType, intrId,
18973-
{arg}, nullptr, "hlsl.asuint");
18974-
18975-
Value *arg0 = Builder->CreateExtractValue(CI, 0);
18976-
Value *arg1 = Builder->CreateExtractValue(CI, 1);
18977-
18978-
return std::make_pair(arg0, arg1);
18979-
};
18980-
1898118967
CallArgList Args;
1898218968
auto [Op1BaseLValue, Op1TmpLValue] =
1898318969
EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
1898418970
auto [Op2BaseLValue, Op2TmpLValue] =
1898518971
EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
1898618972

18987-
llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
18973+
if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil) {
1898818974

18989-
if (!Op0->getType()->isVectorTy()) {
18990-
auto [arg0, arg1] = emitSplitDouble(&Builder, CGM.getHLSLRuntime().getSplitdoubleIntrinsic(), Op0, retType);
18991-
18992-
Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
18993-
auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
18994-
18995-
EmitWritebacks(*this, Args);
18996-
return s;
18997-
}
18975+
llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty);
1899818976

18999-
auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
18977+
if (Op0->getType()->isVectorTy()) {
18978+
auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>();
1900018979

19001-
llvm::VectorType *i32VecTy = llvm::VectorType::get(
19002-
Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
18980+
llvm::VectorType *i32VecTy = llvm::VectorType::get(
18981+
Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
18982+
retType = llvm::StructType::get(i32VecTy, i32VecTy);
18983+
}
1900318984

19004-
std::pair<Value *, Value *> inserts = std::make_pair(nullptr, nullptr);
18985+
CallInst *CI =
18986+
Builder.CreateIntrinsic(retType, Intrinsic::dx_splitdouble, {Op0},
18987+
nullptr, "hlsl.splitdouble");
1900518988

19006-
for (uint64_t idx = 0; idx < Op0VecTy->getNumElements(); idx++) {
19007-
Value *op = Builder.CreateExtractElement(Op0, idx);
18989+
Value *arg0 = Builder.CreateExtractValue(CI, 0);
18990+
Value *arg1 = Builder.CreateExtractValue(CI, 1);
1900818991

19009-
auto [arg0, arg1] = emitSplitDouble(&Builder, CGM.getHLSLRuntime().getSplitdoubleIntrinsic(), op, retType);
18992+
Builder.CreateStore(arg0, Op1TmpLValue.getAddress());
18993+
auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress());
1901018994

19011-
if (idx == 0) {
19012-
inserts.first = Builder.CreateInsertElement(i32VecTy, arg0, idx);
19013-
inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx);
19014-
} else {
19015-
inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx);
19016-
inserts.second = Builder.CreateInsertElement(inserts.second, arg1, idx);
19017-
}
18995+
EmitWritebacks(*this, Args);
18996+
return s;
1901818997
}
19019-
19020-
Builder.CreateStore(inserts.first, Op1TmpLValue.getAddress());
19021-
auto *s = Builder.CreateStore(inserts.second, Op2TmpLValue.getAddress());
19022-
EmitWritebacks(*this, Args);
19023-
return s;
1902418998
}
1902518999
}
1902619000
return nullptr;

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ class CGHLSLRuntime {
8888
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
8989
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
9090
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
91-
GENERATE_HLSL_INTRINSIC_FUNCTION(Splitdouble, splitdouble);
9291
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
9392
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
9493
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)

clang/test/CodeGenHLSL/builtins/splitdouble.hlsl

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,23 @@
11
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s
2-
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv--vulkan-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s --check-prefix=SPIRV
32

43

54

6-
// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
7-
// CHECK: [[VALRET:%hlsl.asuint.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
5+
// CHECK: define {{.*}} i32 {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
6+
// CHECK: [[VALRET:%.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]])
87
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
98
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
10-
// SPIRV: define spir_func {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]])
11-
// SPIRV-NOT: @llvm.dx.splitdouble
12-
// SPIRV: [[REG:%.*]] = load double, ptr [[VALD]].addr
13-
// SPIRV: call spir_func void {{.*}}asuint{{.*}}(double {{.*}} [[REG]], {{.*}})
14-
float test_scalar(double D) {
9+
uint test_scalar(double D) {
1510
uint A, B;
1611
asuint(D, A, B);
1712
return A + B;
1813
}
1914

2015

21-
// CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
22-
// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]]
23-
// CHECK-NEXT: [[VALRET:%hlsl.asuint.*]] = {{.*}} call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]])
24-
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0
25-
// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1
26-
// SPIRV: define spir_func {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
27-
// SPIRV-NOT: @llvm.dx.splitdouble
28-
// SPIRV: [[REG:%.*]] = load <3 x double>, ptr [[VALD]].addr
29-
// SPIRV: call spir_func void {{.*}}asuint{{.*}}(<3 x double> {{.*}} [[REG]], {{.*}})
30-
float3 test_vector(double3 D) {
16+
// CHECK: define {{.*}} <3 x i32> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]])
17+
// CHECK: [[VALRET:%.*]] = {{.*}} call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> [[VALD]])
18+
// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 0
19+
// CHECK-NEXT: extractvalue { <3 x i32>, <3 x i32> } [[VALRET]], 1
20+
uint3 test_vector(double3 D) {
3121
uint3 A, B;
3222
asuint(D, A, B);
3323
return A + B;

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,6 @@ def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>
9595

9696
def int_dx_splitdouble : DefaultAttrsIntrinsic<
9797
[llvm_anyint_ty, LLVMMatchType<0>],
98-
[llvm_double_ty],
98+
[LLVMScalarOrSameVectorWidth<0, llvm_double_ty>],
9999
[IntrNoMem, IntrWillReturn]>;
100100
}

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,12 @@ class OpLowerer {
493493

494494
Value *Arg0 = CI->getArgOperand(0);
495495

496+
if (Arg0->getType()->isVectorTy()) {
497+
return make_error<StringError>(
498+
"splitdouble doesn't support lowering vector types.",
499+
inconvertibleErrorCode());
500+
}
501+
496502
Type *NewRetTy = OpBuilder.getResSplitDoubleType(M.getContext());
497503

498504
std::array<Value *, 1> Args{Arg0};
Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,17 @@
1-
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
2-
; RUN: opt -S --scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
1+
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
2+
; RUN: opt -S --scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
33

4-
; Make sure DXILOpLowering is correctly generating the dxil op code call, with and without scalarizer.
4+
; Make sure DXILOpLowering is correctly generating the dxil op, with and without scalarizer.
55

6-
; CHECK-LABEL: define noundef float @test_scalar_double_split
7-
define noundef float @test_scalar_double_split(double noundef %D) local_unnamed_addr {
6+
; CHECK-LABEL: define noundef i32 @test_scalar_double_split
7+
define noundef i32 @test_scalar_double_split(double noundef %D) local_unnamed_addr {
88
entry:
99
; CHECK: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double %D)
1010
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
1111
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
12-
%hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
13-
%0 = extractvalue { i32, i32 } %hlsl.asuint, 0
14-
%1 = extractvalue { i32, i32 } %hlsl.asuint, 1
12+
%hlsl.splitdouble = call { i32, i32 } @llvm.dx.splitdouble.i32(double %D)
13+
%0 = extractvalue { i32, i32 } %hlsl.splitdouble, 0
14+
%1 = extractvalue { i32, i32 } %hlsl.splitdouble, 1
1515
%add = add i32 %0, %1
16-
%conv = uitofp i32 %add to float
17-
ret float %conv
18-
}
19-
20-
declare <2 x i32> @llvm.dx.splitdouble.v2i32(double) #1
21-
22-
23-
; CHECK-LABEL: define noundef <3 x float> @test_vector_double_split
24-
define noundef <3 x float> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
25-
entry:
26-
%0 = extractelement <3 x double> %D, i64 0
27-
; CHECK-COUNT-3: [[CALL:%.*]] = call %dx.types.splitdouble @dx.op.splitDouble.f64(i32 102, double {{.*}})
28-
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
29-
; CHECK-NEXT:extractvalue %dx.types.splitdouble [[CALL]], {{[0-1]}}
30-
%hlsl.asuint = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %0)
31-
%1 = extractvalue { i32, i32 } %hlsl.asuint, 0
32-
%2 = extractvalue { i32, i32 } %hlsl.asuint, 1
33-
%3 = insertelement <3 x i32> poison, i32 %1, i64 0
34-
%4 = insertelement <3 x i32> poison, i32 %2, i64 0
35-
%5 = extractelement <3 x double> %D, i64 1
36-
%hlsl.asuint2 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %5)
37-
%6 = extractvalue { i32, i32 } %hlsl.asuint2, 0
38-
%7 = extractvalue { i32, i32 } %hlsl.asuint2, 1
39-
%8 = insertelement <3 x i32> %3, i32 %6, i64 1
40-
%9 = insertelement <3 x i32> %4, i32 %7, i64 1
41-
%10 = extractelement <3 x double> %D, i64 2
42-
%hlsl.asuint3 = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double %10)
43-
%11 = extractvalue { i32, i32 } %hlsl.asuint3, 0
44-
%12 = extractvalue { i32, i32 } %hlsl.asuint3, 1
45-
%13 = insertelement <3 x i32> %8, i32 %11, i64 2
46-
%14 = insertelement <3 x i32> %9, i32 %12, i64 2
47-
%add = add <3 x i32> %13, %14
48-
%conv = uitofp <3 x i32> %add to <3 x float>
49-
ret <3 x float> %conv
16+
ret i32 %add
5017
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
; RUN: not opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s
2+
3+
; DXIL operation splitdouble doesn't support vector types.
4+
; CHECK: in function test_vector_double_split
5+
; CHECK-SAME: splitdouble doesn't support lowering vector types.
6+
7+
define noundef <3 x i32> @test_vector_double_split(<3 x double> noundef %D) local_unnamed_addr {
8+
entry:
9+
%hlsl.splitdouble = tail call { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double> %D)
10+
%0 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.splitdouble, 0
11+
%1 = extractvalue { <3 x i32>, <3 x i32> } %hlsl.splitdouble, 1
12+
%add = add <3 x i32> %0, %1
13+
ret <3 x i32> %add
14+
}
15+
16+
declare { <3 x i32>, <3 x i32> } @llvm.dx.splitdouble.v3i32(<3 x double>)

0 commit comments

Comments
 (0)