Skip to content

Commit a2a15fd

Browse files
author
Anagha Rajendra Rao
committed
add float16 type check
1 parent 2f9769b commit a2a15fd

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

clang/lib/Sema/SemaSPIRV.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
7070
PassedType->isVectorType()
7171
? PassedType->castAs<clang::VectorType>()->getElementType()
7272
: PassedType;
73-
if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
73+
if (!BaseType->isHalfType() && !BaseType->isFloat16Type() &&
74+
!BaseType->isFloat32Type())
7475
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
7576
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
7677
<< /* half or float */ 2 << PassedType;
@@ -80,7 +81,8 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
8081
static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
8182
int ArgOrdinal,
8283
clang::QualType PassedType) {
83-
if (!PassedType->isHalfType() && !PassedType->isFloat32Type())
84+
if (!PassedType->isHalfType() && !PassedType->isFloat16Type() &&
85+
!PassedType->isFloat32Type())
8486
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
8587
<< ArgOrdinal << /* scalar */ 1 << /* no int */ 0
8688
<< /* half or float */ 2 << PassedType;
@@ -287,7 +289,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
287289
if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
288290
llvm::ArrayRef(ChecksArr)))
289291
return true;
290-
// Check that first two arguments are vectors of the same type
292+
// Check that first two arguments are vectors/scalars of the same type
291293
QualType Arg0Type = TheCall->getArg(0)->getType();
292294
if (!SemaRef.getASTContext().hasSameUnqualifiedType(
293295
Arg0Type, TheCall->getArg(1)->getType()))

clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@ float1 test_vec1_inputs(float1 p0, float1 p1, float1 p2) {
5454
// expected-note@hlsl/hlsl_intrinsics.h:* {{candidate template ignored: substitution failure [with L = 1]: no type named 'Type' in 'hlsl::__detail::enable_if<false, float>'}}
5555
}
5656

57-
float3 test_mixed_datatype_inputs(float3 p0, float3 p1, half p2) {
58-
return refract(p0, p1, p2);
59-
}
60-
6157
typedef float float5 __attribute__((ext_vector_type(5)));
6258

6359
float5 test_vec5_inputs(float5 p0, float5 p1, float p2) {

clang/test/SemaSPIRV/BuiltIns/refract-errors.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ float test_int_scalar_inputs(int p0, int p1, int p2) {
2727

2828
float test_float_and_half_inputs(float2 p0, half2 p1, float p2) {
2929
return __builtin_spirv_refract(p0, p1, p2);
30-
// expected-error@-1 {{2nd argument must be a scalar or vector of 16 or 32 bit floating-point types (was 'half2' (vector of 2 'half' values))}}
30+
// expected-error@-1 {{first two arguments to '__builtin_spirv_refract' must have the same type}}
3131
}
3232

3333
float test_float_and_half_2_inputs(float2 p0, float2 p1, half p2) {
3434
return __builtin_spirv_refract(p0, p1, p2);
35-
// expected-error@-1 {{3rd argument must be a scalar 16 or 32 bit floating-point type (was 'half' (aka '_Float16'))}}
35+
// expected-error@-1 {{all arguments to '__builtin_spirv_refract' must be of scalar or vector type with matching scalar element type: 'float2' (vector of 2 'float' values) vs 'half' (aka '_Float16')}}
3636
}
3737

3838
float2 test_mismatch_vector_size_inputs(float2 p0, float3 p1, float p2) {

0 commit comments

Comments
 (0)