Skip to content

Commit 7620f9f

Browse files
committed
[HLSL] select scalar overloads for vector conditions
This PR adds scalar/vector overloads for vector conditions to the `select` builtin, and updates the sema checking and codegen to allow scalars to extend to vectors. Fixes #126570 clang-format clang-format 'cbieneman/select' on '44f0fe9a2806'.
1 parent d90423e commit 7620f9f

File tree

7 files changed

+135
-100
lines changed

7 files changed

+135
-100
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12707,6 +12707,9 @@ def err_hlsl_param_qualifier_mismatch :
1270712707
def err_hlsl_vector_compound_assignment_truncation : Error<
1270812708
"left hand operand of type %0 to compound assignment cannot be truncated "
1270912709
"when used with right hand operand of type %1">;
12710+
def err_hlsl_builtin_scalar_vector_mismatch : Error<
12711+
"%select{all|second and third}0 arguments to %1 must be of scalar or "
12712+
"vector type with matching scalar element type%diff{: $ vs $|}2,3">;
1271012713

1271112714
def warn_hlsl_impcast_vector_truncation : Warning<
1271212715
"implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19836,6 +19836,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1983619836
RValFalse.isScalar()
1983719837
? RValFalse.getScalarVal()
1983819838
: RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
19839+
if (auto *VTy = E->getType()->getAs<VectorType>()) {
19840+
if (!OpTrue->getType()->isVectorTy())
19841+
OpTrue =
19842+
Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
19843+
if (!OpFalse->getType()->isVectorTy())
19844+
OpFalse =
19845+
Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
19846+
}
1983919847

1984019848
Value *SelectVal =
1984119849
Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");

clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2123,6 +2123,42 @@ template <typename T, int Sz>
21232123
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
21242124
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
21252125

2126+
2127+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
2128+
/// vector<T,Sz> FalseVals)
2129+
/// \brief ternary operator for vectors. All vectors must be the same size.
2130+
/// \param Conds The Condition input values.
2131+
/// \param TrueVal The scalar value to splat from when conditions are true.
2132+
/// \param FalseVals The vector values are chosen from when conditions are
2133+
/// false.
2134+
2135+
template <typename T, int Sz>
2136+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
2137+
vector<T, Sz> select(vector<bool, Sz>, T, vector<T, Sz>);
2138+
2139+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
2140+
/// T FalseVal)
2141+
/// \brief ternary operator for vectors. All vectors must be the same size.
2142+
/// \param Conds The Condition input values.
2143+
/// \param TrueVals The vector values are chosen from when conditions are true.
2144+
/// \param FalseVal The scalar value to splat from when conditions are false.
2145+
2146+
template <typename T, int Sz>
2147+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
2148+
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, T);
2149+
2150+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
2151+
/// T FalseVal)
2152+
/// \brief ternary operator for vectors. All vectors must be the same size.
2153+
/// \param Conds The Condition input values.
2154+
/// \param TrueVal The scalar value to splat from when conditions are true.
2155+
/// \param FalseVal The scalar value to splat from when conditions are false.
2156+
2157+
template <typename T, int Sz>
2158+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
2159+
__detail::enable_if_t<__detail::is_arithmetic<T>::Value, vector<T, Sz>> select(
2160+
vector<bool, Sz>, T, T);
2161+
21262162
//===----------------------------------------------------------------------===//
21272163
// sin builtins
21282164
//===----------------------------------------------------------------------===//

clang/lib/Headers/hlsl/hlsl_detail.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
9797
#endif
9898
}
9999

100+
template<typename T>
101+
struct is_arithmetic {
102+
static const bool Value = __is_arithmetic(T);
103+
};
104+
100105
} // namespace __detail
101106
} // namespace hlsl
102107
#endif //_HLSL_HLSL_DETAILS_H_

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,40 +2225,48 @@ static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
22252225
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
22262226
assert(TheCall->getNumArgs() == 3);
22272227
Expr *Arg1 = TheCall->getArg(1);
2228+
QualType Arg1Ty = Arg1->getType();
22282229
Expr *Arg2 = TheCall->getArg(2);
2229-
if (!Arg1->getType()->isVectorType()) {
2230-
S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
2231-
<< "Second" << TheCall->getDirectCallee() << Arg1->getType()
2230+
QualType Arg2Ty = Arg2->getType();
2231+
2232+
QualType Arg1ScalarTy = Arg1Ty;
2233+
if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
2234+
Arg1ScalarTy = VTy->getElementType();
2235+
2236+
QualType Arg2ScalarTy = Arg2Ty;
2237+
if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
2238+
Arg2ScalarTy = VTy->getElementType();
2239+
2240+
if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
2241+
S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
2242+
<< /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
2243+
2244+
QualType Arg0Ty = TheCall->getArg(0)->getType();
2245+
unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
2246+
unsigned Arg1Length = Arg1Ty->isVectorType()
2247+
? Arg1Ty->getAs<VectorType>()->getNumElements()
2248+
: 0;
2249+
unsigned Arg2Length = Arg2Ty->isVectorType()
2250+
? Arg2Ty->getAs<VectorType>()->getNumElements()
2251+
: 0;
2252+
if (Arg1Length > 0 && Arg0Length != Arg1Length) {
2253+
S->Diag(TheCall->getBeginLoc(),
2254+
diag::err_typecheck_vector_lengths_not_equal)
2255+
<< Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
22322256
<< Arg1->getSourceRange();
22332257
return true;
22342258
}
22352259

2236-
if (!Arg2->getType()->isVectorType()) {
2237-
S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
2238-
<< "Third" << TheCall->getDirectCallee() << Arg2->getType()
2239-
<< Arg2->getSourceRange();
2240-
return true;
2241-
}
2242-
2243-
if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
2260+
if (Arg2Length > 0 && Arg0Length != Arg2Length) {
22442261
S->Diag(TheCall->getBeginLoc(),
2245-
diag::err_typecheck_call_different_arg_types)
2246-
<< Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
2262+
diag::err_typecheck_vector_lengths_not_equal)
2263+
<< Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
22472264
<< Arg2->getSourceRange();
22482265
return true;
22492266
}
22502267

2251-
// caller has checked that Arg0 is a vector.
2252-
// check all three args have the same length.
2253-
if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
2254-
Arg1->getType()->getAs<VectorType>()->getNumElements()) {
2255-
S->Diag(TheCall->getBeginLoc(),
2256-
diag::err_typecheck_vector_lengths_not_equal)
2257-
<< TheCall->getArg(0)->getType() << Arg1->getType()
2258-
<< TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
2259-
return true;
2260-
}
2261-
TheCall->setType(Arg1->getType());
2268+
TheCall->setType(
2269+
S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
22622270
return false;
22632271
}
22642272

clang/test/CodeGenHLSL/builtins/select.hlsl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,32 @@ int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
5252
int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
5353
return select(cond0, tVals, fVals);
5454
}
55+
56+
// CHECK-LABEL: test_select_vector_scalar_vector
57+
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
58+
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
59+
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> {{%.*}}
60+
// CHECK: ret <4 x i32> [[SELECT]]
61+
int4 test_select_vector_scalar_vector(bool4 cond0, int tVal, int4 fVals) {
62+
return select(cond0, tVal, fVals);
63+
}
64+
65+
// CHECK-LABEL: test_select_vector_vector_scalar
66+
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
67+
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
68+
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> [[SPLAT1]]
69+
// CHECK: ret <4 x i32> [[SELECT]]
70+
int4 test_select_vector_vector_scalar(bool4 cond0, int4 tVals, int fVal) {
71+
return select(cond0, tVals, fVal);
72+
}
73+
74+
// CHECK-LABEL: test_select_vector_scalar_scalar
75+
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
76+
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
77+
// CHECK: [[SPLAT_SRC2:%.*]] = insertelement <4 x i32> poison, i32 %3, i64 0
78+
// CHECK: [[SPLAT2:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC2]], <4 x i32> poison, <4 x i32> zeroinitializer
79+
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> [[SPLAT2]]
80+
// CHECK: ret <4 x i32> [[SELECT]]
81+
int4 test_select_vector_scalar_scalar(bool4 cond0, int tVal, int fVal) {
82+
return select(cond0, tVal, fVal);
83+
}
Lines changed: 22 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,65 @@
1-
// RUN: %clang_cc1 -finclude-default-header
2-
// -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only
3-
// -disable-llvm-passes -verify -verify-ignore-unexpected
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
42

5-
int test_no_arg() {
6-
return select();
7-
// expected-error@-1 {{no matching function for call to 'select'}}
8-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template
9-
// not viable: requires 3 arguments, but 0 were provided}}
10-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
11-
// viable: requires 3 arguments, but 0 were provided}}
12-
}
13-
14-
int test_too_few_args(bool p0) {
15-
return select(p0);
16-
// expected-error@-1 {{no matching function for call to 'select'}}
17-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
18-
// viable: requires 3 arguments, but 1 was provided}}
19-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
20-
// viable: requires 3 arguments, but 1 was provided}}
21-
}
22-
23-
int test_too_many_args(bool p0, int t0, int f0, int g0) {
24-
return select<int>(p0, t0, f0, g0);
25-
// expected-error@-1 {{no matching function for call to 'select'}}
26-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
27-
// viable: requires 3 arguments, but 4 were provided}}
28-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
29-
// viable: requires 3 arguments, but 4 were provided}}
30-
}
313

324
int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) {
335
return select(p0, t0, f0);
34-
// expected-error@-1 {{no matching function for call to 'select'}}
35-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
36-
// viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value)
37-
// to 'bool' for 1st argument}}
38-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could
39-
// not match 'vector<T, Sz>' against 'int'}}
406
}
417

428
int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) {
439
return select<int1>(p0, t0, f0);
44-
// expected-warning@-1 {{implicit conversion truncates vector:
45-
// 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
46-
// (vector of 1 'int' value)}}
4710
}
4811

4912
int2 test_select_vector_vals_not_vecs(bool2 p0, int t0,
5013
int f0) {
5114
return select(p0, t0, f0);
52-
// expected-error@-1 {{no matching function for call to 'select'}}
53-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored:
54-
// could not match 'vector<T, Sz>' against 'int'}}
55-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
56-
// viable: no known conversion from 'vector<bool, 2>'
57-
// (vector of 2 'bool' values) to 'bool' for 1st argument}}
5815
}
5916

6017
int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) {
61-
return select<int,1>(p0, t0, f0); // produce warnings
62-
// expected-warning@-1 {{implicit conversion truncates vector:
63-
// 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>'
64-
// (vector of 1 'bool' value)}}
65-
// expected-warning@-2 {{implicit conversion truncates vector:
66-
// 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
67-
// (vector of 1 'int' value)}}
18+
return select<int,1>(p0, t0, f0); // expected-warning{{implicit conversion truncates vector: 'bool2' (aka 'vector<bool, 2>') to 'vector<bool, 1>' (vector of 1 'bool' value)}}
19+
}
20+
21+
int test_select_no_args() {
22+
return __builtin_hlsl_select(); // expected-error{{too few arguments to function call, expected 3, have 0}}
23+
}
24+
25+
int test_select_builtin_wrong_arg_count(bool p0) {
26+
return __builtin_hlsl_select(p0); // expected-error{{too few arguments to function call, expected 3, have 1}}
6827
}
6928

7029
// __builtin_hlsl_select tests
71-
int test_select_builtin_wrong_arg_count(bool p0, int t0) {
72-
return __builtin_hlsl_select(p0, t0);
73-
// expected-error@-1 {{too few arguments to function call, expected 3,
74-
// have 2}}
30+
int test_select_builtin_wrong_arg_count2(bool p0, int t0) {
31+
return __builtin_hlsl_select(p0, t0); // expected-error{{too few arguments to function call, expected 3, have 2}}
32+
}
33+
34+
int test_too_many_args(bool p0, int t0, int f0, int g0) {
35+
return __builtin_hlsl_select(p0, t0, f0, g0); // expected-error{{too many arguments to function call, expected 3, have 4}}
7536
}
7637

7738
// not a bool or a vector of bool. should be 2 errors.
7839
int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
79-
return __builtin_hlsl_select(p0, t0, f0);
80-
// expected-error@-1 {{passing 'int' to parameter of incompatible type
81-
// 'bool'}}
82-
// expected-error@-2 {{First argument to __builtin_hlsl_select must be of
83-
// vector type}}
84-
}
40+
return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int' where 'bool' or a vector of such type is required}}
41+
}
8542

8643
int test_select_builtin_first_arg_wrong_type2(int1 p0, int t0, int f0) {
87-
return __builtin_hlsl_select(p0, t0, f0);
88-
// expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to
89-
// parameter of incompatible type 'bool'}}
90-
// expected-error@-2 {{First argument to __builtin_hlsl_select must be of
91-
// vector type}}
44+
return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int1' (aka 'vector<int, 1>') where 'bool' or a vector of such type is required}}
9245
}
9346

9447
// if a bool last 2 args are of same type
9548
int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
96-
return __builtin_hlsl_select(p0, t0, f0);
97-
// expected-error@-1 {{arguments are of different types ('int' vs 'double')}}
49+
return __builtin_hlsl_select(p0, t0, f0); // expected-error{{arguments are of different types ('int' vs 'double')}}
9850
}
9951

10052
// if a vector second arg isnt a vector
10153
int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) {
10254
return __builtin_hlsl_select(p0, t0, f0);
103-
// expected-error@-1 {{Second argument to __builtin_hlsl_select must be of
104-
// vector type}}
10555
}
10656

10757
// if a vector third arg isn't a vector
10858
int2 test_select_builtin_second_arg_not_vector(bool2 p0, int2 t0, int f0) {
10959
return __builtin_hlsl_select(p0, t0, f0);
110-
// expected-error@-1 {{Third argument to __builtin_hlsl_select must be of
111-
// vector type}}
11260
}
11361

11462
// if vector last 2 aren't same type (so both are vectors but wrong type)
115-
int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
116-
return __builtin_hlsl_select(p0, t0, f0);
117-
// expected-error@-1 {{arguments are of different types ('vector<int, [...]>'
118-
// vs 'vector<float, [...]>')}}
63+
int1 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
64+
return __builtin_hlsl_select(p0, t0, f0); // expected-error{{second and third arguments to __builtin_hlsl_select must be of scalar or vector type with matching scalar element type: 'vector<int, [...]>' vs 'vector<float, [...]>'}}
11965
}

0 commit comments

Comments
 (0)