Skip to content

Commit 43f17ed

Browse files
committed
Addressed PR feedback
1 parent b5ee3cb commit 43f17ed

File tree

7 files changed

+46
-23
lines changed

7 files changed

+46
-23
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4894,7 +4894,7 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
48944894
def HLSLDot2Add : LangBuiltin<"HLSL_LANG"> {
48954895
let Spellings = ["__builtin_hlsl_dot2add"];
48964896
let Attributes = [NoThrow, Const, CustomTypeChecking];
4897-
let Prototype = "void(...)";
4897+
let Prototype = "float(_ExtVector<2, _Float16>,_ExtVector<2, _Float16>, float)";
48984898
}
48994899

49004900
def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> {

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19683,18 +19683,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1968319683
}
1968419684
case Builtin::BI__builtin_hlsl_dot2add: {
1968519685
llvm::Triple::ArchType Arch = CGM.getTarget().getTriple().getArch();
19686-
if (Arch != llvm::Triple::dxil) {
19687-
llvm_unreachable("Intrinsic dot2add can be executed as a builtin only on dxil");
19688-
}
19686+
assert(Arch == llvm::Triple::dxil && "Intrinsic dot2add can be executed as a builtin only on dxil");
1968919687
Value *A = EmitScalarExpr(E->getArg(0));
1969019688
Value *B = EmitScalarExpr(E->getArg(1));
1969119689
Value *C = EmitScalarExpr(E->getArg(2));
1969219690

19693-
//llvm::Intrinsic::dx_##IntrinsicPostfix
1969419691
Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
1969519692
return Builder.CreateIntrinsic(
1969619693
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
19697-
"hlsl.dot2add");
19694+
"dx.dot2add");
1969819695
}
1969919696
case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
1970019697
Value *A = EmitScalarExpr(E->getArg(0));

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,15 @@ const inline float distance(__detail::HLSL_FIXED_VECTOR<float, N> X,
121121
// dot2add builtins
122122
//===----------------------------------------------------------------------===//
123123

124-
/// \fn float dot2add(half2 a, half2 b, float c)
124+
/// \fn float dot2add(half2 A, half2 B, float C)
125125
/// \brief Dot product of 2 vector of type half and add a float scalar value.
126+
/// \param A The first input value to dot product.
127+
/// \param B The second input value to dot product.
128+
/// \param C The input value added to the dot product.
126129

127130
_HLSL_AVAILABILITY(shadermodel, 6.4)
128-
const inline float dot2add(half2 a, half2 b, float c) {
129-
return __detail::dot2add_impl(a, b, c);
131+
const inline float dot2add(half2 A, half2 B, float C) {
132+
return __detail::dot2add_impl(A, B, C);
130133
}
131134

132135
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2473,11 +2473,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
24732473
break;
24742474
}
24752475
case Builtin::BI__builtin_hlsl_dot2add: {
2476-
// Check number of arguments should be 3
24772476
if (SemaRef.checkArgCount(TheCall, 3))
24782477
return true;
24792478

2480-
// Check first two arguments are vector of length 2 with half data type
24812479
auto checkHalfVectorOfSize2 = [](clang::QualType PassedType) -> bool {
24822480
if (const auto *VecTy = PassedType->getAs<VectorType>())
24832481
return !(VecTy->getNumElements() == 2 &&
@@ -2493,10 +2491,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
24932491
checkHalfVectorOfSize2))
24942492
return true;
24952493

2496-
// Check third argument is a float
24972494
if (CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), SemaRef.getASTContext().FloatTy))
24982495
return true;
2499-
TheCall->setType(TheCall->getArg(2)->getType());
2496+
TheCall->setType(SemaRef.getASTContext().FloatTy);
25002497
break;
25012498
}
25022499
case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
// RUN: %clang_cc1 -finclude-default-header -triple \
1+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
22
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
33
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4-
// RUN: %clang_cc1 -finclude-default-header -triple \
4+
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
55
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
66
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
77

88
// Test basic lowering to runtime function call.
99

1010
float test(half2 p1, half2 p2, float p3) {
11-
// CHECK-SPIRV: %[[MUL:.*]] = call {{.*}} float @llvm.spv.fdot.v2f32(<2 x float> %1, <2 x float> %2)
11+
// CHECK-SPIRV: %[[MUL:.*]] = call {{.*}} half @llvm.spv.fdot.v2f16(<2 x half> %1, <2 x half> %2)
12+
// CHECK-SPIRV: %[[CONVERT:.*]] = fpext {{.*}} half %[[MUL:.*]] to float
1213
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr, align 4
13-
// CHECK-SPIRV: %[[RES:.*]] = fadd {{.*}} float %[[MUL]], %[[C]]
14-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f32(<2 x float> %0, <2 x float> %1, float %2)
14+
// CHECK-SPIRV: %[[RES:.*]] = fadd {{.*}} float %[[CONVERT:.*]], %[[C:.*]]
15+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %0, <2 x half> %1, float %2)
1516
// CHECK: ret float %[[RES]]
1617
return dot2add(p1, p2, p3);
17-
}
18+
}
Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,36 @@
1-
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify
22

3-
bool test_too_few_arg() {
3+
float test_too_few_arg() {
44
return __builtin_hlsl_dot2add();
55
// expected-error@-1 {{too few arguments to function call, expected 3, have 0}}
66
}
77

8-
bool test_too_many_arg(half2 p1, half2 p2, float p3) {
8+
float test_too_many_arg(half2 p1, half2 p2, float p3) {
99
return __builtin_hlsl_dot2add(p1, p2, p3, p1);
1010
// expected-error@-1 {{too many arguments to function call, expected 3, have 4}}
1111
}
12+
13+
float test_float_arg2_type(half2 p1, float2 p2, float p3) {
14+
return __builtin_hlsl_dot2add(p1, p2, p3);
15+
// expected-error@-1 {{passing 'float2' (aka 'vector<float, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(half)))) half' (vector of 2 'half' values)}}
16+
}
17+
18+
float test_float_arg1_type(float2 p1, half2 p2, float p3) {
19+
return __builtin_hlsl_dot2add(p1, p2, p3);
20+
// expected-error@-1 {{passing 'float2' (aka 'vector<float, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(half)))) half' (vector of 2 'half' values)}}
21+
}
22+
23+
float test_double_arg3_type(half2 p1, half2 p2, double p3) {
24+
return __builtin_hlsl_dot2add(p1, p2, p3);
25+
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'float'}}
26+
}
27+
28+
float test_float_arg1_arg2_type(float2 p1, float2 p2, float p3) {
29+
return __builtin_hlsl_dot2add(p1, p2, p3);
30+
// expected-error@-1 {{passing 'float2' (aka 'vector<float, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(half)))) half' (vector of 2 'half' values)}}
31+
}
32+
33+
float test_int16_arg1_arg2_type(int16_t2 p1, int16_t2 p2, float p3) {
34+
return __builtin_hlsl_dot2add(p1, p2, p3);
35+
// expected-error@-1 {{passing 'int16_t2' (aka 'vector<int16_t, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(half)))) half' (vector of 2 'half' values)}}
36+
}

llvm/test/CodeGen/DirectX/dot2add.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ entry:
55
; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %c, half %0, half %1, half %2, half %3)
66
%ret = call float @llvm.dx.dot2add(<2 x half> %a, <2 x half> %b, float %c)
77
ret float %ret
8-
}
8+
}

0 commit comments

Comments
 (0)