Skip to content

Commit f074649

Browse files
committed
add a target lowering pass through
1 parent fb81c1d commit f074649

File tree

17 files changed

+623
-88
lines changed

17 files changed

+623
-88
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4810,6 +4810,12 @@ def HLSLDegrees : LangBuiltin<"HLSL_LANG"> {
48104810
let Prototype = "void(...)";
48114811
}
48124812

4813+
def HLSLDistance : LangBuiltin<"HLSL_LANG"> {
4814+
let Spellings = ["__builtin_hlsl_distance"];
4815+
let Attributes = [NoThrow, Const];
4816+
let Prototype = "void(...)";
4817+
}
4818+
48134819
def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
48144820
let Spellings = ["__builtin_hlsl_dot"];
48154821
let Attributes = [NoThrow, Const];
@@ -4846,6 +4852,12 @@ def HLSLIsinf : LangBuiltin<"HLSL_LANG"> {
48464852
let Prototype = "void(...)";
48474853
}
48484854

4855+
def HLSLLength : LangBuiltin<"HLSL_LANG"> {
4856+
let Spellings = ["__builtin_hlsl_length"];
4857+
let Attributes = [NoThrow, Const];
4858+
let Prototype = "void(...)";
4859+
}
4860+
48494861
def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
48504862
let Spellings = ["__builtin_hlsl_lerp"];
48514863
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19097,6 +19097,20 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1909719097
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getCrossIntrinsic(),
1909819098
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.cross");
1909919099
}
19100+
case Builtin::BI__builtin_hlsl_distance: {
19101+
Value *X = EmitScalarExpr(E->getArg(0));
19102+
Value *Y = EmitScalarExpr(E->getArg(1));
19103+
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
19104+
E->getArg(1)->getType()->hasFloatingRepresentation() &&
19105+
"Distance operands must have a float representation");
19106+
assert(E->getArg(0)->getType()->isVectorType() &&
19107+
E->getArg(1)->getType()->isVectorType() &&
19108+
"Distance operands must be a vector");
19109+
return Builder.CreateIntrinsic(
19110+
/*ReturnType=*/X->getType()->getScalarType(),
19111+
CGM.getHLSLRuntime().getDistanceIntrinsic(), ArrayRef<Value *>{X, Y},
19112+
nullptr, "hlsl.distance");
19113+
}
1910019114
case Builtin::BI__builtin_hlsl_dot: {
1910119115
Value *Op0 = EmitScalarExpr(E->getArg(0));
1910219116
Value *Op1 = EmitScalarExpr(E->getArg(1));
@@ -19174,6 +19188,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1917419188
/*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(),
1917519189
ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp");
1917619190
}
19191+
case Builtin::BI__builtin_hlsl_length: {
19192+
Value *X = EmitScalarExpr(E->getArg(0));
19193+
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
19194+
"length operand must have a float representation");
19195+
assert(E->getArg(0)->getType()->isVectorType() &&
19196+
"length operand must be a vector");
19197+
return Builder.CreateIntrinsic(
19198+
/*ReturnType=*/X->getType()->getScalarType(),
19199+
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
19200+
nullptr, "hlsl.length");
19201+
}
1917719202
case Builtin::BI__builtin_hlsl_normalize: {
1917819203
Value *X = EmitScalarExpr(E->getArg(0));
1917919204

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,36 @@
3030
#include <optional>
3131
#include <vector>
3232

33+
#define GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(FunctionName, \
34+
IntrinsicPostfix) \
35+
GENERATE_HLSL_INTRINSIC_FUNCTION(FunctionName, IntrinsicPostfix, 1, 1)
36+
3337
// A function generator macro for picking the right intrinsic
3438
// for the target backend
35-
#define GENERATE_HLSL_INTRINSIC_FUNCTION(FunctionName, IntrinsicPostfix) \
39+
#define GENERATE_HLSL_INTRINSIC_FUNCTION(FunctionName, IntrinsicPostfix, \
40+
IncludeDXIL, IncludeSPIRV) \
3641
llvm::Intrinsic::ID get##FunctionName##Intrinsic() { \
3742
llvm::Triple::ArchType Arch = getArch(); \
3843
switch (Arch) { \
39-
case llvm::Triple::dxil: \
40-
return llvm::Intrinsic::dx_##IntrinsicPostfix; \
41-
case llvm::Triple::spirv: \
42-
return llvm::Intrinsic::spv_##IntrinsicPostfix; \
44+
/* Include DXIL case only if IncludeDXIL is true */ \
45+
IF_INCLUDE(IncludeDXIL, case llvm::Triple::dxil \
46+
: return llvm::Intrinsic::dx_##IntrinsicPostfix;) \
47+
/* Include SPIRV case only if IncludeSPIRV is true */ \
48+
IF_INCLUDE(IncludeSPIRV, case llvm::Triple::spirv \
49+
: return llvm::Intrinsic::spv_##IntrinsicPostfix;) \
50+
\
4351
default: \
4452
llvm_unreachable("Intrinsic " #IntrinsicPostfix \
4553
" not supported by target architecture"); \
4654
} \
4755
}
4856

57+
#define IF_INCLUDE(Condition, Code) IF_INCLUDE_IMPL(Condition, Code)
58+
#define IF_INCLUDE_IMPL(Condition, Code) IF_INCLUDE_##Condition(Code)
59+
60+
#define IF_INCLUDE_1(Code) Code
61+
#define IF_INCLUDE_0(Code)
62+
4963
namespace llvm {
5064
class GlobalVariable;
5165
class Function;
@@ -72,36 +86,43 @@ class CGHLSLRuntime {
7286
// Start of reserved area for HLSL intrinsic getters.
7387
//===----------------------------------------------------------------------===//
7488

75-
GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
76-
GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
77-
GENERATE_HLSL_INTRINSIC_FUNCTION(Cross, cross)
78-
GENERATE_HLSL_INTRINSIC_FUNCTION(Degrees, degrees)
79-
GENERATE_HLSL_INTRINSIC_FUNCTION(Frac, frac)
80-
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
81-
GENERATE_HLSL_INTRINSIC_FUNCTION(Normalize, normalize)
82-
GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
83-
GENERATE_HLSL_INTRINSIC_FUNCTION(Saturate, saturate)
84-
GENERATE_HLSL_INTRINSIC_FUNCTION(Sign, sign)
85-
GENERATE_HLSL_INTRINSIC_FUNCTION(Step, step)
86-
GENERATE_HLSL_INTRINSIC_FUNCTION(Radians, radians)
87-
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
88-
GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
89-
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
90-
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
91-
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
92-
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
93-
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
94-
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
95-
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
96-
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
97-
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
98-
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
99-
GENERATE_HLSL_INTRINSIC_FUNCTION(NClamp, nclamp)
100-
GENERATE_HLSL_INTRINSIC_FUNCTION(SClamp, sclamp)
101-
GENERATE_HLSL_INTRINSIC_FUNCTION(UClamp, uclamp)
102-
103-
GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)
104-
GENERATE_HLSL_INTRINSIC_FUNCTION(BufferUpdateCounter, bufferUpdateCounter)
89+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(All, all)
90+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Any, any)
91+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Cross, cross)
92+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Degrees, degrees)
93+
GENERATE_HLSL_INTRINSIC_FUNCTION(Distance, distance, /*IncludeDXIL*/ 0,
94+
/*IncludeSPIRV*/ 1)
95+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Frac, frac)
96+
GENERATE_HLSL_INTRINSIC_FUNCTION(Length, length, /*IncludeDXIL*/ 0,
97+
/*IncludeSPIRV*/ 1)
98+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Lerp, lerp)
99+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Normalize, normalize)
100+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Rsqrt, rsqrt)
101+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Saturate, saturate)
102+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Sign, sign)
103+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Step, step)
104+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Radians, radians)
105+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(ThreadId, thread_id)
106+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(FDot, fdot)
107+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(SDot, sdot)
108+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(UDot, udot)
109+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Dot4AddI8Packed, dot4add_i8packed)
110+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(Dot4AddU8Packed, dot4add_u8packed)
111+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(WaveActiveAnyTrue, wave_any)
112+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(WaveActiveCountBits,
113+
wave_active_countbits)
114+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(WaveIsFirstLane, wave_is_first_lane)
115+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(WaveReadLaneAt, wave_readlane)
116+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(FirstBitUHigh, firstbituhigh)
117+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(FirstBitSHigh, firstbitshigh)
118+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(NClamp, nclamp)
119+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(SClamp, sclamp)
120+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(UClamp, uclamp)
121+
122+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(CreateHandleFromBinding,
123+
handle_fromBinding)
124+
GENERATE_HLSL_INTRINSIC_FUNCTION_DEFAULT(BufferUpdateCounter,
125+
bufferUpdateCounter)
105126

106127
//===----------------------------------------------------------------------===//
107128
// End of reserved area for HLSL intrinsic getters.

clang/lib/Headers/hlsl/hlsl_detail.h

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,34 @@
99
#ifndef _HLSL_HLSL_DETAILS_H_
1010
#define _HLSL_HLSL_DETAILS_H_
1111

12+
#if __is_target_arch(dxil)
13+
#define IS_ARCH_DXIL 1
14+
#else
15+
#define IS_ARCH_DXIL 0
16+
#endif
17+
18+
#if __is_target_arch(spirv)
19+
#define IS_ARCH_SPIRV 1
20+
#else
21+
#define IS_ARCH_SPIRV 0
22+
#endif
23+
24+
#define ARCH_CONDITION(arch) \
25+
if (IS_ARCH_##arch) \
26+
return true;
27+
28+
// Note: arch is used to bypass
29+
// the generic implementation
30+
#define EXPAND_ARCH_CONDITIONS(arch) \
31+
ARCH_CONDITION(arch) \
32+
/* Add more architectures as needed */
33+
34+
#define DEFINE_TARGET_LOWERING(function_name, ...) \
35+
constexpr bool Has##function_name##Lowering() { \
36+
EXPAND_ARCH_CONDITIONS(__VA_ARGS__) \
37+
return false; /* Default case if no match */ \
38+
}
39+
1240
namespace hlsl {
1341

1442
namespace __detail {
@@ -41,6 +69,7 @@ constexpr enable_if_t<sizeof(U) == sizeof(T), U> bit_cast(T F) {
4169
return __builtin_bit_cast(U, F);
4270
}
4371

72+
DEFINE_TARGET_LOWERING(Length, SPIRV)
4473
template <typename T>
4574
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
4675
length_impl(T X) {
@@ -50,15 +79,14 @@ length_impl(T X) {
5079
template <typename T, int N>
5180
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
5281
length_vec_impl(vector<T, N> X) {
82+
if (HasLengthLowering())
83+
return __builtin_hlsl_length(X);
5384
vector<T, N> XSquared = X * X;
5485
T XSquaredSum = __builtin_hlsl_reduce_add(XSquared);
55-
/*T XSquaredSum = 0;
56-
for(int I = 0; I < N; I++) {
57-
XSquaredSum += XSquared[I];
58-
}*/
5986
return __builtin_elementwise_sqrt(XSquaredSum);
6087
}
6188

89+
DEFINE_TARGET_LOWERING(Distance, SPIRV)
6290
template <typename T>
6391
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
6492
distance_impl(T X, T Y) {
@@ -68,6 +96,8 @@ distance_impl(T X, T Y) {
6896
template <typename T, int N>
6997
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
7098
distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
99+
if (HasDistanceLowering())
100+
return __builtin_hlsl_distance(X, Y);
71101
return length_vec_impl(X - Y);
72102
}
73103

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -879,22 +879,21 @@ float4 degrees(float4);
879879
/// \brief Returns a distance scalar between two vectors of \a X and \a Y.
880880
/// \param X The X input value.
881881
/// \param Y The Y input value.
882-
883882
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
884883
const inline half distance(half X, half Y) {
885884
return __detail::distance_impl(X, Y);
886885
}
887886

887+
const inline float distance(float X, float Y) {
888+
return __detail::distance_impl(X, Y);
889+
}
890+
888891
template <int N>
889892
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
890893
const inline half distance(vector<half, N> X, vector<half, N> Y) {
891894
return __detail::distance_vec_impl(X, Y);
892895
}
893896

894-
const inline float distance(float X, float Y) {
895-
return __detail::distance_impl(X, Y);
896-
}
897-
898897
template <int N>
899898
const inline float distance(vector<float, N> X, vector<float, N> Y) {
900899
return __detail::distance_vec_impl(X, Y);
@@ -1325,18 +1324,16 @@ float4 lerp(float4, float4, float4);
13251324
/// \param x [in] The vector of floats, or a scalar float.
13261325
///
13271326
/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + ...).
1328-
13291327
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
13301328
const inline half length(half X) { return __detail::length_impl(X); }
1329+
const inline float length(float X) { return __detail::length_impl(X); }
13311330

13321331
template <int N>
13331332
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
13341333
const inline half length(vector<half, N> X) {
13351334
return __detail::length_vec_impl(X);
13361335
}
13371336

1338-
const inline float length(float X) { return __detail::length_impl(X); }
1339-
13401337
template <int N> const inline float length(vector<float, N> X) {
13411338
return __detail::length_vec_impl(X);
13421339
}

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,6 +1976,20 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
19761976
TheCall->setType(ArgTyA);
19771977
break;
19781978
}
1979+
case Builtin::BI__builtin_hlsl_distance: {
1980+
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
1981+
return true;
1982+
if (SemaRef.checkArgCount(TheCall, 2))
1983+
return true;
1984+
if (CheckVectorElementCallArgs(&SemaRef, TheCall))
1985+
return true;
1986+
ExprResult A = TheCall->getArg(0);
1987+
QualType ArgTyA = A.get()->getType();
1988+
auto *VTy = ArgTyA->getAs<VectorType>();
1989+
QualType RetTy = VTy->getElementType();
1990+
TheCall->setType(RetTy);
1991+
break;
1992+
}
19791993
case Builtin::BI__builtin_hlsl_dot: {
19801994
if (SemaRef.checkArgCount(TheCall, 2))
19811995
return true;
@@ -2063,18 +2077,23 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
20632077
return true;
20642078
break;
20652079
}
2066-
case Builtin::BI__builtin_hlsl_reduce_add: {
2080+
case Builtin::BI__builtin_hlsl_length: {
2081+
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
2082+
return true;
20672083
if (SemaRef.checkArgCount(TheCall, 1))
20682084
return true;
20692085
ExprResult A = TheCall->getArg(0);
20702086
QualType ArgTyA = A.get()->getType();
2071-
QualType RetTy;
2072-
2073-
if (auto *VTy = ArgTyA->getAs<VectorType>())
2074-
RetTy = VTy->getElementType();
2075-
else
2076-
RetTy = TheCall->getArg(0)->getType();
2077-
2087+
auto *VTy = ArgTyA->getAs<VectorType>();
2088+
if (VTy == nullptr) {
2089+
SemaRef.Diag(A.get()->getBeginLoc(),
2090+
diag::err_typecheck_convert_incompatible)
2091+
<< ArgTyA
2092+
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
2093+
<< 0 << 0;
2094+
return true;
2095+
}
2096+
QualType RetTy = VTy->getElementType();
20782097
TheCall->setType(RetTy);
20792098
break;
20802099
}
@@ -2101,6 +2120,21 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
21012120
TheCall->setType(ArgTyA);
21022121
break;
21032122
}
2123+
case Builtin::BI__builtin_hlsl_reduce_add: {
2124+
if (SemaRef.checkArgCount(TheCall, 1))
2125+
return true;
2126+
ExprResult A = TheCall->getArg(0);
2127+
QualType ArgTyA = A.get()->getType();
2128+
QualType RetTy;
2129+
2130+
if (auto *VTy = ArgTyA->getAs<VectorType>())
2131+
RetTy = VTy->getElementType();
2132+
else
2133+
RetTy = TheCall->getArg(0)->getType();
2134+
2135+
TheCall->setType(RetTy);
2136+
break;
2137+
}
21042138
case Builtin::BI__builtin_hlsl_elementwise_sign: {
21052139
if (CheckFloatingOrIntRepresentation(&SemaRef, TheCall))
21062140
return true;

0 commit comments

Comments
 (0)