Skip to content

Commit 03b137a

Browse files
committed
[HLSL] Move length to the header
1 parent 9eb7438 commit 03b137a

File tree

12 files changed

+367
-263
lines changed

12 files changed

+367
-263
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4756,6 +4756,12 @@ def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> {
47564756
let Prototype = "void(...)";
47574757
}
47584758

4759+
def HLSLReduceAdd : LangBuiltin<"HLSL_LANG"> {
4760+
let Spellings = ["__builtin_hlsl_reduce_add"];
4761+
let Attributes = [NoThrow, Const, Constexpr];
4762+
let Prototype = "void(...)";
4763+
}
4764+
47594765
def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
47604766
let Spellings = ["__builtin_hlsl_wave_active_any_true"];
47614767
let Attributes = [NoThrow, Const];
@@ -4840,12 +4846,6 @@ def HLSLIsinf : LangBuiltin<"HLSL_LANG"> {
48404846
let Prototype = "void(...)";
48414847
}
48424848

4843-
def HLSLLength : LangBuiltin<"HLSL_LANG"> {
4844-
let Spellings = ["__builtin_hlsl_length"];
4845-
let Attributes = [NoThrow, Const];
4846-
let Prototype = "void(...)";
4847-
}
4848-
48494849
def HLSLLerp : LangBuiltin<"HLSL_LANG"> {
48504850
let Spellings = ["__builtin_hlsl_lerp"];
48514851
let Attributes = [NoThrow, Const];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19174,20 +19174,6 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1917419174
/*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(),
1917519175
ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp");
1917619176
}
19177-
case Builtin::BI__builtin_hlsl_length: {
19178-
Value *X = EmitScalarExpr(E->getArg(0));
19179-
19180-
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
19181-
"length operand must have a float representation");
19182-
// if the operand is a scalar, we can use the fabs llvm intrinsic directly
19183-
if (!E->getArg(0)->getType()->isVectorType())
19184-
return EmitFAbs(*this, X);
19185-
19186-
return Builder.CreateIntrinsic(
19187-
/*ReturnType=*/X->getType()->getScalarType(),
19188-
CGM.getHLSLRuntime().getLengthIntrinsic(), ArrayRef<Value *>{X},
19189-
nullptr, "hlsl.length");
19190-
}
1919119177
case Builtin::BI__builtin_hlsl_normalize: {
1919219178
Value *X = EmitScalarExpr(E->getArg(0));
1919319179

@@ -19289,6 +19275,21 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1928919275
CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
1929019276
nullptr, "hlsl.saturate");
1929119277
}
19278+
case Builtin::BI__builtin_hlsl_reduce_add: {
19279+
Value *X = EmitScalarExpr(E->getArg(0));
19280+
auto EltTy = X->getType()->getScalarType();
19281+
if (E->getArg(0)->getType()->hasFloatingRepresentation()) {
19282+
Value *Seed = ConstantFP::get(EltTy, 0);
19283+
return Builder.CreateIntrinsic(
19284+
/*ReturnType=*/EltTy, llvm::Intrinsic::vector_reduce_fadd,
19285+
ArrayRef<Value *>{Seed, X}, nullptr, "rdx.fadd");
19286+
} else {
19287+
assert(E->getArg(0)->getType()->hasIntegerRepresentation());
19288+
return Builder.CreateIntrinsic(
19289+
/*ReturnType=*/EltTy, llvm::Intrinsic::vector_reduce_add,
19290+
ArrayRef<Value *>{X}, nullptr, "rdx.add");
19291+
}
19292+
}
1929219293
case Builtin::BI__builtin_hlsl_select: {
1929319294
Value *OpCond = EmitScalarExpr(E->getArg(0));
1929419295
RValue RValTrue = EmitAnyExpr(E->getArg(1));

clang/lib/Headers/hlsl/hlsl_detail.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,34 @@ constexpr enable_if_t<sizeof(U) == sizeof(T), U> bit_cast(T F) {
4141
return __builtin_bit_cast(U, F);
4242
}
4343

44+
template <typename T>
45+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
46+
length_impl(T X) {
47+
return __builtin_elementwise_abs(X);
48+
}
49+
50+
template <typename T, int N>
51+
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
52+
length_vec_impl(vector<T, N> X) {
53+
vector<T, N> XSquared = X * X;
54+
T XSquaredSum = __builtin_hlsl_reduce_add(XSquared);
55+
/*T XSquaredSum = 0;
56+
for(int I = 0; I < N; I++) {
57+
XSquaredSum += XSquared[I];
58+
}*/
59+
return __builtin_elementwise_sqrt(XSquaredSum);
60+
}
61+
4462
template <typename T>
4563
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
4664
distance_impl(T X, T Y) {
47-
return __builtin_elementwise_abs(X - Y);
65+
return length_impl(X - Y);
4866
}
4967

5068
template <typename T, int N>
5169
constexpr enable_if_t<is_same<float, T>::value || is_same<half, T>::value, T>
5270
distance_vec_impl(vector<T, N> X, vector<T, N> Y) {
53-
return __builtin_hlsl_length(X - Y);
71+
return length_vec_impl(X - Y);
5472
}
5573

5674
} // namespace __detail

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,26 +1327,19 @@ float4 lerp(float4, float4, float4);
13271327
/// Length is based on the following formula: sqrt(x[0]^2 + x[1]^2 + ...).
13281328

13291329
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1330-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1331-
half length(half);
1332-
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1333-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1334-
half length(half2);
1335-
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1336-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1337-
half length(half3);
1338-
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1339-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1340-
half length(half4);
1341-
1342-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1343-
float length(float);
1344-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1345-
float length(float2);
1346-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1347-
float length(float3);
1348-
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_length)
1349-
float length(float4);
1330+
const inline half length(half X) { return __detail::length_impl(X); }
1331+
1332+
template <int N>
1333+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
1334+
const inline half length(vector<half, N> X) {
1335+
return __detail::length_vec_impl(X);
1336+
}
1337+
1338+
const inline float length(float X) { return __detail::length_impl(X); }
1339+
1340+
template <int N> const inline float length(vector<float, N> X) {
1341+
return __detail::length_vec_impl(X);
1342+
}
13501343

13511344
//===----------------------------------------------------------------------===//
13521345
// log builtins

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,12 +2063,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
20632063
return true;
20642064
break;
20652065
}
2066-
case Builtin::BI__builtin_hlsl_length: {
2067-
if (CheckFloatOrHalfRepresentations(&SemaRef, TheCall))
2068-
return true;
2066+
case Builtin::BI__builtin_hlsl_reduce_add: {
20692067
if (SemaRef.checkArgCount(TheCall, 1))
20702068
return true;
2071-
20722069
ExprResult A = TheCall->getArg(0);
20732070
QualType ArgTyA = A.get()->getType();
20742071
QualType RetTy;

clang/test/CodeGenHLSL/builtins/distance.hlsl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,32 @@ half test_distance_half(half X, half Y) { return distance(X, Y); }
1616
// CHECK-SAME: <2 x half> noundef [[X:%.*]], <2 x half> noundef [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
1717
// CHECK-NEXT: [[ENTRY:.*:]]
1818
// CHECK-NEXT: [[SUB_I:%.*]] = fsub <2 x half> [[X]], [[Y]]
19-
// CHECK-NEXT: [[HLSL_LENGTH_I:%.*]] = tail call noundef half @llvm.dx.length.v2f16(<2 x half> [[SUB_I]])
20-
// CHECK-NEXT: ret half [[HLSL_LENGTH_I]]
19+
// CHECK-NEXT: [[MUL_I:%.*]] = fmul <2 x half> [[SUB_I]], [[SUB_I]]
20+
// CHECK-NEXT: [[RDX_FADD_I:%.*]] = tail call half @llvm.vector.reduce.fadd.v2f16(half 0xH0000, <2 x half> [[MUL_I]])
21+
// CHECK-NEXT: [[TMP0:%.*]] = tail call noundef half @llvm.sqrt.f16(half [[RDX_FADD_I]])
22+
// CHECK-NEXT: ret half [[TMP0]]
2123
//
2224
half test_distance_half2(half2 X, half2 Y) { return distance(X, Y); }
2325

2426
// CHECK-LABEL: define noundef half @_Z19test_distance_half3Dv3_DhS_(
2527
// CHECK-SAME: <3 x half> noundef [[X:%.*]], <3 x half> noundef [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
2628
// CHECK-NEXT: [[ENTRY:.*:]]
2729
// CHECK-NEXT: [[SUB_I:%.*]] = fsub <3 x half> [[X]], [[Y]]
28-
// CHECK-NEXT: [[HLSL_LENGTH_I:%.*]] = tail call noundef half @llvm.dx.length.v3f16(<3 x half> [[SUB_I]])
29-
// CHECK-NEXT: ret half [[HLSL_LENGTH_I]]
30+
// CHECK-NEXT: [[MUL_I:%.*]] = fmul <3 x half> [[SUB_I]], [[SUB_I]]
31+
// CHECK-NEXT: [[RDX_FADD_I:%.*]] = tail call half @llvm.vector.reduce.fadd.v3f16(half 0xH0000, <3 x half> [[MUL_I]])
32+
// CHECK-NEXT: [[TMP0:%.*]] = tail call noundef half @llvm.sqrt.f16(half [[RDX_FADD_I]])
33+
// CHECK-NEXT: ret half [[TMP0]]
3034
//
3135
half test_distance_half3(half3 X, half3 Y) { return distance(X, Y); }
3236

3337
// CHECK-LABEL: define noundef half @_Z19test_distance_half4Dv4_DhS_(
3438
// CHECK-SAME: <4 x half> noundef [[X:%.*]], <4 x half> noundef [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
3539
// CHECK-NEXT: [[ENTRY:.*:]]
3640
// CHECK-NEXT: [[SUB_I:%.*]] = fsub <4 x half> [[X]], [[Y]]
37-
// CHECK-NEXT: [[HLSL_LENGTH_I:%.*]] = tail call noundef half @llvm.dx.length.v4f16(<4 x half> [[SUB_I]])
38-
// CHECK-NEXT: ret half [[HLSL_LENGTH_I]]
41+
// CHECK-NEXT: [[MUL_I:%.*]] = fmul <4 x half> [[SUB_I]], [[SUB_I]]
42+
// CHECK-NEXT: [[RDX_FADD_I:%.*]] = tail call half @llvm.vector.reduce.fadd.v4f16(half 0xH0000, <4 x half> [[MUL_I]])
43+
// CHECK-NEXT: [[TMP0:%.*]] = tail call noundef half @llvm.sqrt.f16(half [[RDX_FADD_I]])
44+
// CHECK-NEXT: ret half [[TMP0]]
3945
//
4046
half test_distance_half4(half4 X, half4 Y) { return distance(X, Y); }
4147

@@ -52,25 +58,31 @@ float test_distance_float(float X, float Y) { return distance(X, Y); }
5258
// CHECK-SAME: <2 x float> noundef [[X:%.*]], <2 x float> noundef [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
5359
// CHECK-NEXT: [[ENTRY:.*:]]
5460
// CHECK-NEXT: [[SUB_I:%.*]] = fsub <2 x float> [[X]], [[Y]]
55-
// CHECK-NEXT: [[HLSL_LENGTH_I:%.*]] = tail call noundef float @llvm.dx.length.v2f32(<2 x float> [[SUB_I]])
56-
// CHECK-NEXT: ret float [[HLSL_LENGTH_I]]
61+
// CHECK-NEXT: [[MUL_I:%.*]] = fmul <2 x float> [[SUB_I]], [[SUB_I]]
62+
// CHECK-NEXT: [[RDX_FADD_I:%.*]] = tail call float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[MUL_I]])
63+
// CHECK-NEXT: [[TMP0:%.*]] = tail call noundef float @llvm.sqrt.f32(float [[RDX_FADD_I]])
64+
// CHECK-NEXT: ret float [[TMP0]]
5765
//
5866
float test_distance_float2(float2 X, float2 Y) { return distance(X, Y); }
5967

6068
// CHECK-LABEL: define noundef float @_Z20test_distance_float3Dv3_fS_(
6169
// CHECK-SAME: <3 x float> noundef [[X:%.*]], <3 x float> noundef [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
6270
// CHECK-NEXT: [[ENTRY:.*:]]
6371
// CHECK-NEXT: [[SUB_I:%.*]] = fsub <3 x float> [[X]], [[Y]]
64-
// CHECK-NEXT: [[HLSL_LENGTH_I:%.*]] = tail call noundef float @llvm.dx.length.v3f32(<3 x float> [[SUB_I]])
65-
// CHECK-NEXT: ret float [[HLSL_LENGTH_I]]
72+
// CHECK-NEXT: [[MUL_I:%.*]] = fmul <3 x float> [[SUB_I]], [[SUB_I]]
73+
// CHECK-NEXT: [[RDX_FADD_I:%.*]] = tail call float @llvm.vector.reduce.fadd.v3f32(float 0.000000e+00, <3 x float> [[MUL_I]])
74+
// CHECK-NEXT: [[TMP0:%.*]] = tail call noundef float @llvm.sqrt.f32(float [[RDX_FADD_I]])
75+
// CHECK-NEXT: ret float [[TMP0]]
6676
//
6777
float test_distance_float3(float3 X, float3 Y) { return distance(X, Y); }
6878

6979
// CHECK-LABEL: define noundef float @_Z20test_distance_float4Dv4_fS_(
7080
// CHECK-SAME: <4 x float> noundef [[X:%.*]], <4 x float> noundef [[Y:%.*]]) local_unnamed_addr #[[ATTR0]] {
7181
// CHECK-NEXT: [[ENTRY:.*:]]
7282
// CHECK-NEXT: [[SUB_I:%.*]] = fsub <4 x float> [[X]], [[Y]]
73-
// CHECK-NEXT: [[HLSL_LENGTH_I:%.*]] = tail call noundef float @llvm.dx.length.v4f32(<4 x float> [[SUB_I]])
74-
// CHECK-NEXT: ret float [[HLSL_LENGTH_I]]
83+
// CHECK-NEXT: [[MUL_I:%.*]] = fmul <4 x float> [[SUB_I]], [[SUB_I]]
84+
// CHECK-NEXT: [[RDX_FADD_I:%.*]] = tail call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> [[MUL_I]])
85+
// CHECK-NEXT: [[TMP0:%.*]] = tail call noundef float @llvm.sqrt.f32(float [[RDX_FADD_I]])
86+
// CHECK-NEXT: ret float [[TMP0]]
7587
//
7688
float test_distance_float4(float4 X, float4 Y) { return distance(X, Y); }

0 commit comments

Comments
 (0)