Skip to content

Commit 056f0a1

Browse files
authored
[HLSL][DXIL] Implement refract intrinsic (#147342)
- [x] Implement refract using HLSL source in hlsl_intrinsics.h - [x] Implement the refract SPIR-V target built-in in clang/include/clang/Basic/BuiltinsSPIRV.td - [x] Add sema checks for refract to CheckSPIRVBuiltinFunctionCall in clang/lib/Sema/SemaSPIRV.cpp - [x] Add codegen for spv refract to EmitSPIRVBuiltinExpr in CGBuiltin.cpp - [x] Add codegen tests to clang/test/CodeGenHLSL/builtins/refract.hlsl - [x] Add spv codegen test to clang/test/CodeGenSPIRV/Builtins/refract.c - [x] Add sema tests to clang/test/SemaHLSL/BuiltIns/refract-errors.hlsl - [x] Add spv sema tests to clang/test/SemaSPIRV/BuiltIns/refract-errors.c - [x] Create the int_spv_refract intrinsic in IntrinsicsSPIRV.td - [x] In SPIRVInstructionSelector.cpp create the refract lowering and map it to int_spv_refract in SPIRVInstructionSelector::selectIntrinsic. - [x] Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/refract.ll - [x] Check for what OpenCL support is needed. Resolves #99153
1 parent 22994ed commit 056f0a1

File tree

13 files changed

+643
-1
lines changed

13 files changed

+643
-1
lines changed

clang/include/clang/Basic/BuiltinsSPIRVVK.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ include "clang/Basic/BuiltinsSPIRVBase.td"
1111

1212
def reflect : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
1313
def faceforward : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
14+
def refract : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;

clang/lib/CodeGen/TargetBuiltins/SPIR.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,18 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
5858
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
5959
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
6060
}
61+
case SPIRV::BI__builtin_spirv_refract: {
62+
Value *I = EmitScalarExpr(E->getArg(0));
63+
Value *N = EmitScalarExpr(E->getArg(1));
64+
Value *eta = EmitScalarExpr(E->getArg(2));
65+
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
66+
E->getArg(1)->getType()->hasFloatingRepresentation() &&
67+
E->getArg(2)->getType()->isFloatingType() &&
68+
"refract operands must have a float representation");
69+
return Builder.CreateIntrinsic(
70+
/*ReturnType=*/I->getType(), Intrinsic::spv_refract,
71+
ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
72+
}
6173
case SPIRV::BI__builtin_spirv_smoothstep: {
6274
Value *Min = EmitScalarExpr(E->getArg(0));
6375
Value *Max = EmitScalarExpr(E->getArg(1));

clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
7171
#endif
7272
}
7373

74+
template <typename T, typename U> constexpr T refract_impl(T I, T N, U Eta) {
75+
#if (__has_builtin(__builtin_spirv_refract))
76+
return __builtin_spirv_refract(I, N, Eta);
77+
#endif
78+
T Mul = dot(N, I);
79+
T K = 1 - Eta * Eta * (1 - Mul * Mul);
80+
T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
81+
return select<T>(K < 0, static_cast<T>(0), Result);
82+
}
83+
7484
template <typename T> constexpr T fmod_impl(T X, T Y) {
7585
#if !defined(__DIRECTX__)
7686
return __builtin_elementwise_fmod(X, Y);

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
475475
return __detail::reflect_vec_impl(I, N);
476476
}
477477

478+
//===----------------------------------------------------------------------===//
479+
// refract builtin
480+
//===----------------------------------------------------------------------===//
481+
482+
/// \fn T refract(T I, T N, T eta)
483+
/// \brief Returns a refraction using an entering ray, \a I, a surface
484+
/// normal, \a N and refraction index \a eta
485+
/// \param I The entering ray.
486+
/// \param N The surface normal.
487+
/// \param eta The refraction index.
488+
///
489+
/// The return value is a floating-point vector that represents the refraction
490+
/// using the refraction index, \a eta, for the direction of the entering ray,
491+
/// \a I, off a surface with the normal \a N.
492+
///
493+
/// This function calculates the refraction vector using the following formulas:
494+
/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
495+
/// if k < 0.0 the result is 0.0
496+
/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
497+
///
498+
/// I and N must already be normalized in order to achieve the desired result.
499+
///
500+
/// I and N must be a scalar or vector whose component type is
501+
/// floating-point.
502+
///
503+
/// eta must be a 16-bit or 32-bit floating-point scalar.
504+
///
505+
/// Result type, the type of I, and the type of N must all be the same type.
506+
507+
template <typename T>
508+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
509+
const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
510+
__detail::is_same<half, T>::value,
511+
T> refract(T I, T N, T eta) {
512+
return __detail::refract_impl(I, N, eta);
513+
}
514+
515+
template <typename T>
516+
const inline __detail::enable_if_t<
517+
__detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
518+
refract(T I, T N, T eta) {
519+
return __detail::refract_impl(I, N, eta);
520+
}
521+
522+
template <int L>
523+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
524+
const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
525+
__detail::HLSL_FIXED_VECTOR<half, L> I,
526+
__detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
527+
return __detail::refract_impl(I, N, eta);
528+
}
529+
530+
template <int L>
531+
const inline __detail::HLSL_FIXED_VECTOR<float, L>
532+
refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
533+
__detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
534+
return __detail::refract_impl(I, N, eta);
535+
}
536+
478537
//===----------------------------------------------------------------------===//
479538
// smoothstep builtin
480539
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaSPIRV.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,49 @@ static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
4646
return false;
4747
}
4848

49+
static bool CheckAllArgTypesAreCorrect(
50+
Sema *S, CallExpr *TheCall,
51+
llvm::ArrayRef<
52+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
53+
Checks) {
54+
unsigned NumArgs = TheCall->getNumArgs();
55+
assert(Checks.size() == NumArgs &&
56+
"Wrong number of checks for Number of args.");
57+
// Apply each check to the corresponding argument
58+
for (unsigned I = 0; I < NumArgs; ++I) {
59+
Expr *Arg = TheCall->getArg(I);
60+
if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
61+
return true;
62+
}
63+
return false;
64+
}
65+
66+
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
67+
int ArgOrdinal,
68+
clang::QualType PassedType) {
69+
clang::QualType BaseType =
70+
PassedType->isVectorType()
71+
? PassedType->castAs<clang::VectorType>()->getElementType()
72+
: PassedType;
73+
if (!BaseType->isHalfType() && !BaseType->isFloat16Type() &&
74+
!BaseType->isFloat32Type())
75+
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
76+
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
77+
<< /* half or float */ 2 << PassedType;
78+
return false;
79+
}
80+
81+
static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
82+
int ArgOrdinal,
83+
clang::QualType PassedType) {
84+
if (!PassedType->isHalfType() && !PassedType->isFloat16Type() &&
85+
!PassedType->isFloat32Type())
86+
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
87+
<< ArgOrdinal << /* scalar */ 1 << /* no int */ 0
88+
<< /* half or float */ 2 << PassedType;
89+
return false;
90+
}
91+
4992
static std::optional<int>
5093
processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
5194
ExprResult Arg =
@@ -235,6 +278,43 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
235278
TheCall->setType(RetTy);
236279
break;
237280
}
281+
case SPIRV::BI__builtin_spirv_refract: {
282+
if (SemaRef.checkArgCount(TheCall, 3))
283+
return true;
284+
285+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>
286+
ChecksArr[] = {CheckFloatOrHalfRepresentation,
287+
CheckFloatOrHalfRepresentation,
288+
CheckFloatOrHalfScalarRepresentation};
289+
if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
290+
llvm::ArrayRef(ChecksArr)))
291+
return true;
292+
// Check that first two arguments are vectors/scalars of the same type
293+
QualType Arg0Type = TheCall->getArg(0)->getType();
294+
if (!SemaRef.getASTContext().hasSameUnqualifiedType(
295+
Arg0Type, TheCall->getArg(1)->getType()))
296+
return SemaRef.Diag(TheCall->getBeginLoc(),
297+
diag::err_vec_builtin_incompatible_vector)
298+
<< TheCall->getDirectCallee() << /* first two */ 0
299+
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
300+
TheCall->getArg(1)->getEndLoc());
301+
302+
// Check that scalar type of 3rd arg is same as base type of first two args
303+
clang::QualType BaseType =
304+
Arg0Type->isVectorType()
305+
? Arg0Type->castAs<clang::VectorType>()->getElementType()
306+
: Arg0Type;
307+
if (!SemaRef.getASTContext().hasSameUnqualifiedType(
308+
BaseType, TheCall->getArg(2)->getType()))
309+
return SemaRef.Diag(TheCall->getBeginLoc(),
310+
diag::err_hlsl_builtin_scalar_vector_mismatch)
311+
<< /* all */ 0 << TheCall->getDirectCallee() << Arg0Type
312+
<< TheCall->getArg(2)->getType();
313+
314+
QualType RetTy = TheCall->getArg(0)->getType();
315+
TheCall->setType(RetTy);
316+
break;
317+
}
238318
case SPIRV::BI__builtin_spirv_smoothstep: {
239319
if (SemaRef.checkArgCount(TheCall, 3))
240320
return true;

0 commit comments

Comments
 (0)