Skip to content

Commit 999c925

Browse files
committed
[HLSL] select scalar overloads for vector conditions
This PR adds scalar/vector overloads for vector conditions to the `select` builtin, and updates the sema checking and codegen to allow scalars to extend to vectors. Fixes #126570
1 parent b65e094 commit 999c925

File tree

7 files changed

+135
-100
lines changed

7 files changed

+135
-100
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12682,6 +12682,9 @@ def err_hlsl_param_qualifier_mismatch :
1268212682
def err_hlsl_vector_compound_assignment_truncation : Error<
1268312683
"left hand operand of type %0 to compound assignment cannot be truncated "
1268412684
"when used with right hand operand of type %1">;
12685+
def err_hlsl_builtin_scalar_vector_mismatch : Error<
12686+
"%select{all|second and third}0 arguments to %1 must be of scalar or "
12687+
"vector type with matching scalar element type%diff{: $ vs $|}2,3">;
1268512688

1268612689
def warn_hlsl_impcast_vector_truncation : Warning<
1268712690
"implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19741,6 +19741,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1974119741
RValFalse.isScalar()
1974219742
? RValFalse.getScalarVal()
1974319743
: RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
19744+
if (auto *VTy = E->getType()->getAs<VectorType>()) {
19745+
if (!OpTrue->getType()->isVectorTy())
19746+
OpTrue =
19747+
Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
19748+
if (!OpFalse->getType()->isVectorTy())
19749+
OpFalse =
19750+
Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
19751+
}
1974419752

1974519753
Value *SelectVal =
1974619754
Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");

clang/lib/Headers/hlsl/hlsl_detail.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
9595
#endif
9696
}
9797

98+
template<typename T>
99+
struct is_arithmetic {
100+
static const bool Value = __is_arithmetic(T);
101+
};
102+
98103
} // namespace __detail
99104
} // namespace hlsl
100105
#endif //_HLSL_HLSL_DETAILS_H_

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2246,6 +2246,42 @@ template <typename T, int Sz>
22462246
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
22472247
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, vector<T, Sz>);
22482248

2249+
2250+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, T TrueVal,
2251+
/// vector<T,Sz> FalseVals)
2252+
/// \brief ternary operator for vectors. All vectors must be the same size.
2253+
/// \param Conds The Condition input values.
2254+
/// \param TrueVal The scalar value to splat from when conditions are true.
2255+
/// \param FalseVals The vector values are chosen from when conditions are
2256+
/// false.
2257+
2258+
template <typename T, int Sz>
2259+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
2260+
vector<T, Sz> select(vector<bool, Sz>, T, vector<T, Sz>);
2261+
2262+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
2263+
/// T FalseVal)
2264+
/// \brief ternary operator for vectors. All vectors must be the same size.
2265+
/// \param Conds The Condition input values.
2266+
/// \param TrueVals The vector values are chosen from when conditions are true.
2267+
/// \param FalseVal The scalar value to splat from when conditions are false.
2268+
2269+
template <typename T, int Sz>
2270+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
2271+
vector<T, Sz> select(vector<bool, Sz>, vector<T, Sz>, T);
2272+
2273+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals,
2274+
/// T FalseVal)
2275+
/// \brief ternary operator for vectors. All vectors must be the same size.
2276+
/// \param Conds The Condition input values.
2277+
/// \param TrueVal The scalar value to splat from when conditions are true.
2278+
/// \param FalseVal The scalar value to splat from when conditions are false.
2279+
2280+
template <typename T, int Sz>
2281+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
2282+
__detail::enable_if_t<__detail::is_arithmetic<T>::Value, vector<T, Sz>> select(
2283+
vector<bool, Sz>, T, T);
2284+
22492285
//===----------------------------------------------------------------------===//
22502286
// sin builtins
22512287
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2213,40 +2213,48 @@ static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
22132213
static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
22142214
assert(TheCall->getNumArgs() == 3);
22152215
Expr *Arg1 = TheCall->getArg(1);
2216+
QualType Arg1Ty = Arg1->getType();
22162217
Expr *Arg2 = TheCall->getArg(2);
2217-
if (!Arg1->getType()->isVectorType()) {
2218-
S->Diag(Arg1->getBeginLoc(), diag::err_builtin_non_vector_type)
2219-
<< "Second" << TheCall->getDirectCallee() << Arg1->getType()
2218+
QualType Arg2Ty = Arg2->getType();
2219+
2220+
QualType Arg1ScalarTy = Arg1Ty;
2221+
if (auto VTy = Arg1ScalarTy->getAs<VectorType>())
2222+
Arg1ScalarTy = VTy->getElementType();
2223+
2224+
QualType Arg2ScalarTy = Arg2Ty;
2225+
if (auto VTy = Arg2ScalarTy->getAs<VectorType>())
2226+
Arg2ScalarTy = VTy->getElementType();
2227+
2228+
if (!S->Context.hasSameUnqualifiedType(Arg1ScalarTy, Arg2ScalarTy))
2229+
S->Diag(Arg1->getBeginLoc(), diag::err_hlsl_builtin_scalar_vector_mismatch)
2230+
<< /* second and third */ 1 << TheCall->getCallee() << Arg1Ty << Arg2Ty;
2231+
2232+
QualType Arg0Ty = TheCall->getArg(0)->getType();
2233+
unsigned Arg0Length = Arg0Ty->getAs<VectorType>()->getNumElements();
2234+
unsigned Arg1Length = Arg1Ty->isVectorType()
2235+
? Arg1Ty->getAs<VectorType>()->getNumElements()
2236+
: 0;
2237+
unsigned Arg2Length = Arg2Ty->isVectorType()
2238+
? Arg2Ty->getAs<VectorType>()->getNumElements()
2239+
: 0;
2240+
if (Arg1Length > 0 && Arg0Length != Arg1Length) {
2241+
S->Diag(TheCall->getBeginLoc(),
2242+
diag::err_typecheck_vector_lengths_not_equal)
2243+
<< Arg0Ty << Arg1Ty << TheCall->getArg(0)->getSourceRange()
22202244
<< Arg1->getSourceRange();
22212245
return true;
22222246
}
22232247

2224-
if (!Arg2->getType()->isVectorType()) {
2225-
S->Diag(Arg2->getBeginLoc(), diag::err_builtin_non_vector_type)
2226-
<< "Third" << TheCall->getDirectCallee() << Arg2->getType()
2227-
<< Arg2->getSourceRange();
2228-
return true;
2229-
}
2230-
2231-
if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), Arg2->getType())) {
2248+
if (Arg2Length > 0 && Arg0Length != Arg2Length) {
22322249
S->Diag(TheCall->getBeginLoc(),
2233-
diag::err_typecheck_call_different_arg_types)
2234-
<< Arg1->getType() << Arg2->getType() << Arg1->getSourceRange()
2250+
diag::err_typecheck_vector_lengths_not_equal)
2251+
<< Arg0Ty << Arg2Ty << TheCall->getArg(0)->getSourceRange()
22352252
<< Arg2->getSourceRange();
22362253
return true;
22372254
}
22382255

2239-
// caller has checked that Arg0 is a vector.
2240-
// check all three args have the same length.
2241-
if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
2242-
Arg1->getType()->getAs<VectorType>()->getNumElements()) {
2243-
S->Diag(TheCall->getBeginLoc(),
2244-
diag::err_typecheck_vector_lengths_not_equal)
2245-
<< TheCall->getArg(0)->getType() << Arg1->getType()
2246-
<< TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
2247-
return true;
2248-
}
2249-
TheCall->setType(Arg1->getType());
2256+
TheCall->setType(
2257+
S->getASTContext().getExtVectorType(Arg1ScalarTy, Arg0Length));
22502258
return false;
22512259
}
22522260

clang/test/CodeGenHLSL/builtins/select.hlsl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,32 @@ int3 test_select_vector_3(bool3 cond0, int3 tVals, int3 fVals) {
5252
int4 test_select_vector_4(bool4 cond0, int4 tVals, int4 fVals) {
5353
return select(cond0, tVals, fVals);
5454
}
55+
56+
// CHECK-LABEL: test_select_vector_scalar_vector
57+
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
58+
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
59+
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> {{%.*}}
60+
// CHECK: ret <4 x i32> [[SELECT]]
61+
int4 test_select_vector_scalar_vector(bool4 cond0, int tVal, int4 fVals) {
62+
return select(cond0, tVal, fVals);
63+
}
64+
65+
// CHECK-LABEL: test_select_vector_vector_scalar
66+
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
67+
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
68+
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> {{%.*}}, <4 x i32> [[SPLAT1]]
69+
// CHECK: ret <4 x i32> [[SELECT]]
70+
int4 test_select_vector_vector_scalar(bool4 cond0, int4 tVals, int fVal) {
71+
return select(cond0, tVals, fVal);
72+
}
73+
74+
// CHECK-LABEL: test_select_vector_scalar_scalar
75+
// CHECK: [[SPLAT_SRC1:%.*]] = insertelement <4 x i32> poison, i32 {{%.*}}, i64 0
76+
// CHECK: [[SPLAT1:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC1]], <4 x i32> poison, <4 x i32> zeroinitializer
77+
// CHECK: [[SPLAT_SRC2:%.*]] = insertelement <4 x i32> poison, i32 %3, i64 0
78+
// CHECK: [[SPLAT2:%.*]] = shufflevector <4 x i32> [[SPLAT_SRC2]], <4 x i32> poison, <4 x i32> zeroinitializer
79+
// CHECK: [[SELECT:%.*]] = select <4 x i1> {{%.*}}, <4 x i32> [[SPLAT1]], <4 x i32> [[SPLAT2]]
80+
// CHECK: ret <4 x i32> [[SELECT]]
81+
int4 test_select_vector_scalar_scalar(bool4 cond0, int tVal, int fVal) {
82+
return select(cond0, tVal, fVal);
83+
}
Lines changed: 22 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,65 @@
1-
// RUN: %clang_cc1 -finclude-default-header
2-
// -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only
3-
// -disable-llvm-passes -verify -verify-ignore-unexpected
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
42

5-
int test_no_arg() {
6-
return select();
7-
// expected-error@-1 {{no matching function for call to 'select'}}
8-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template
9-
// not viable: requires 3 arguments, but 0 were provided}}
10-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
11-
// viable: requires 3 arguments, but 0 were provided}}
12-
}
13-
14-
int test_too_few_args(bool p0) {
15-
return select(p0);
16-
// expected-error@-1 {{no matching function for call to 'select'}}
17-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
18-
// viable: requires 3 arguments, but 1 was provided}}
19-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
20-
// viable: requires 3 arguments, but 1 was provided}}
21-
}
22-
23-
int test_too_many_args(bool p0, int t0, int f0, int g0) {
24-
return select<int>(p0, t0, f0, g0);
25-
// expected-error@-1 {{no matching function for call to 'select'}}
26-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
27-
// viable: requires 3 arguments, but 4 were provided}}
28-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
29-
// viable: requires 3 arguments, but 4 were provided}}
30-
}
313

324
int test_select_first_arg_wrong_type(int1 p0, int t0, int f0) {
335
return select(p0, t0, f0);
34-
// expected-error@-1 {{no matching function for call to 'select'}}
35-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
36-
// viable: no known conversion from 'vector<int, 1>' (vector of 1 'int' value)
37-
// to 'bool' for 1st argument}}
38-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: could
39-
// not match 'vector<T, Sz>' against 'int'}}
406
}
417

428
int1 test_select_bool_vals_diff_vecs(bool p0, int1 t0, int1 f0) {
439
return select<int1>(p0, t0, f0);
44-
// expected-warning@-1 {{implicit conversion truncates vector:
45-
// 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
46-
// (vector of 1 'int' value)}}
4710
}
4811

4912
int2 test_select_vector_vals_not_vecs(bool2 p0, int t0,
5013
int f0) {
5114
return select(p0, t0, f0);
52-
// expected-error@-1 {{no matching function for call to 'select'}}
53-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored:
54-
// could not match 'vector<T, Sz>' against 'int'}}
55-
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not
56-
// viable: no known conversion from 'vector<bool, 2>'
57-
// (vector of 2 'bool' values) to 'bool' for 1st argument}}
5815
}
5916

6017
int1 test_select_vector_vals_wrong_size(bool2 p0, int1 t0, int1 f0) {
61-
return select<int,1>(p0, t0, f0); // produce warnings
62-
// expected-warning@-1 {{implicit conversion truncates vector:
63-
// 'vector<bool, 2>' (vector of 2 'bool' values) to 'vector<bool, 1>'
64-
// (vector of 1 'bool' value)}}
65-
// expected-warning@-2 {{implicit conversion truncates vector:
66-
// 'vector<int, 2>' (vector of 2 'int' values) to 'vector<int, 1>'
67-
// (vector of 1 'int' value)}}
18+
return select<int,1>(p0, t0, f0); // expected-warning{{implicit conversion truncates vector: 'bool2' (aka 'vector<bool, 2>') to 'vector<bool, 1>' (vector of 1 'bool' value)}}
19+
}
20+
21+
int test_select_no_args() {
22+
return __builtin_hlsl_select(); // expected-error{{too few arguments to function call, expected 3, have 0}}
23+
}
24+
25+
int test_select_builtin_wrong_arg_count(bool p0) {
26+
return __builtin_hlsl_select(p0); // expected-error{{too few arguments to function call, expected 3, have 1}}
6827
}
6928

7029
// __builtin_hlsl_select tests
71-
int test_select_builtin_wrong_arg_count(bool p0, int t0) {
72-
return __builtin_hlsl_select(p0, t0);
73-
// expected-error@-1 {{too few arguments to function call, expected 3,
74-
// have 2}}
30+
int test_select_builtin_wrong_arg_count2(bool p0, int t0) {
31+
return __builtin_hlsl_select(p0, t0); // expected-error{{too few arguments to function call, expected 3, have 2}}
32+
}
33+
34+
int test_too_many_args(bool p0, int t0, int f0, int g0) {
35+
return __builtin_hlsl_select(p0, t0, f0, g0); // expected-error{{too many arguments to function call, expected 3, have 4}}
7536
}
7637

7738
// not a bool or a vector of bool. should be 2 errors.
7839
int test_select_builtin_first_arg_wrong_type1(int p0, int t0, int f0) {
79-
return __builtin_hlsl_select(p0, t0, f0);
80-
// expected-error@-1 {{passing 'int' to parameter of incompatible type
81-
// 'bool'}}
82-
// expected-error@-2 {{First argument to __builtin_hlsl_select must be of
83-
// vector type}}
84-
}
40+
return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int' where 'bool' or a vector of such type is required}}
41+
}
8542

8643
int test_select_builtin_first_arg_wrong_type2(int1 p0, int t0, int f0) {
87-
return __builtin_hlsl_select(p0, t0, f0);
88-
// expected-error@-1 {{passing 'vector<int, 1>' (vector of 1 'int' value) to
89-
// parameter of incompatible type 'bool'}}
90-
// expected-error@-2 {{First argument to __builtin_hlsl_select must be of
91-
// vector type}}
44+
return __builtin_hlsl_select(p0, t0, f0); // expected-error{{invalid operand of type 'int1' (aka 'vector<int, 1>') where 'bool' or a vector of such type is required}}
9245
}
9346

9447
// if a bool last 2 args are of same type
9548
int test_select_builtin_bool_incompatible_args(bool p0, int t0, double f0) {
96-
return __builtin_hlsl_select(p0, t0, f0);
97-
// expected-error@-1 {{arguments are of different types ('int' vs 'double')}}
49+
return __builtin_hlsl_select(p0, t0, f0); // expected-error{{arguments are of different types ('int' vs 'double')}}
9850
}
9951

10052
// if a vector second arg isnt a vector
10153
int2 test_select_builtin_second_arg_not_vector(bool2 p0, int t0, int2 f0) {
10254
return __builtin_hlsl_select(p0, t0, f0);
103-
// expected-error@-1 {{Second argument to __builtin_hlsl_select must be of
104-
// vector type}}
10555
}
10656

10757
// if a vector third arg isn't a vector
10858
int2 test_select_builtin_second_arg_not_vector(bool2 p0, int2 t0, int f0) {
10959
return __builtin_hlsl_select(p0, t0, f0);
110-
// expected-error@-1 {{Third argument to __builtin_hlsl_select must be of
111-
// vector type}}
11260
}
11361

11462
// if vector last 2 aren't same type (so both are vectors but wrong type)
115-
int2 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
116-
return __builtin_hlsl_select(p0, t0, f0);
117-
// expected-error@-1 {{arguments are of different types ('vector<int, [...]>'
118-
// vs 'vector<float, [...]>')}}
63+
int1 test_select_builtin_diff_types(bool1 p0, int1 t0, float1 f0) {
64+
return __builtin_hlsl_select(p0, t0, f0); // expected-error{{second and third arguments to __builtin_hlsl_select must be of scalar or vector type with matching scalar element type: 'vector<int, [...]>' vs 'vector<float, [...]>'}}
11965
}

0 commit comments

Comments
 (0)