Skip to content

Commit 53541ff

Browse files
committed
Modify ternary operator in complex and add it as a functional struct in functional.hlsl . Small change to FFT Indexing utils
1 parent ea31887 commit 53541ff

File tree

3 files changed

+36
-14
lines changed

3 files changed

+36
-14
lines changed

include/nbl/builtin/hlsl/complex.hlsl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#define _NBL_BUILTIN_HLSL_COMPLEX_INCLUDED_
77

88
#include <nbl/builtin/hlsl/cpp_compat.hlsl>
9+
#include <nbl/builtin/hlsl/functional.hlsl>
10+
11+
using namespace nbl::hlsl;
912

1013
// -------------------------------------- CPP VERSION ------------------------------------
1114
#ifndef __HLSL_VERSION
@@ -44,8 +47,6 @@ complex_t<Scalar> rotateRight(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
4447
// -------------------------------------- HLSL VERSION ---------------------------------------
4548
#else
4649

47-
#include "nbl/builtin/hlsl/functional.hlsl"
48-
4950
namespace nbl
5051
{
5152
namespace hlsl
@@ -164,6 +165,8 @@ struct complex_t
164165
template<typename Scalar>
165166
struct plus< complex_t<Scalar> >
166167
{
168+
using type_t = complex_t<Scalar>;
169+
167170
complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
168171
{
169172
return lhs + rhs;
@@ -175,6 +178,8 @@ struct plus< complex_t<Scalar> >
175178
template<typename Scalar>
176179
struct minus< complex_t<Scalar> >
177180
{
181+
using type_t = complex_t<Scalar>;
182+
178183
complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
179184
{
180185
return lhs - rhs;
@@ -186,6 +191,8 @@ struct minus< complex_t<Scalar> >
186191
template<typename Scalar>
187192
struct multiplies< complex_t<Scalar> >
188193
{
194+
using type_t = complex_t<Scalar>;
195+
189196
complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
190197
{
191198
return lhs * rhs;
@@ -202,6 +209,8 @@ struct multiplies< complex_t<Scalar> >
202209
template<typename Scalar>
203210
struct divides< complex_t<Scalar> >
204211
{
212+
using type_t = complex_t<Scalar>;
213+
205214
complex_t<Scalar> operator()(NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
206215
{
207216
return lhs / rhs;
@@ -417,17 +426,20 @@ complex_t<Scalar> rotateRight(NBL_CONST_REF_ARG(complex_t<Scalar>) value)
417426
return retVal;
418427
}
419428

420-
// Annoyed at having to write a lot of boilerplate to do a select
421-
// Essentially returns what you'd expect from doing `condition ? a : b`
422429
template<typename Scalar>
423-
complex_t<Scalar> ternaryOperator(bool condition, NBL_CONST_REF_ARG(complex_t<Scalar>) a, NBL_CONST_REF_ARG(complex_t<Scalar>) b)
430+
struct ternary_operator< complex_t<Scalar> >
424431
{
425-
const vector<Scalar, 2> aVector = vector<Scalar, 2>(a.real(), a.imag());
426-
const vector<Scalar, 2> bVector = vector<Scalar, 2>(b.real(), b.imag());
427-
const vector<Scalar, 2> resultVector = condition ? aVector : bVector;
428-
const complex_t<Scalar> result = { resultVector.x, resultVector.y };
429-
return result;
430-
}
432+
using type_t = complex_t<Scalar>;
433+
434+
complex_t<Scalar> operator()(bool condition, NBL_CONST_REF_ARG(complex_t<Scalar>) lhs, NBL_CONST_REF_ARG(complex_t<Scalar>) rhs)
435+
{
436+
const vector<Scalar, 2> lhsVector = vector<Scalar, 2>(lhs.real(), lhs.imag());
437+
const vector<Scalar, 2> rhsVector = vector<Scalar, 2>(rhs.real(), rhs.imag());
438+
const vector<Scalar, 2> resultVector = condition ? lhsVector : rhsVector;
439+
const complex_t<Scalar> result = { resultVector.x, resultVector.y };
440+
return result;
441+
}
442+
};
431443

432444

433445
}

include/nbl/builtin/hlsl/functional.hlsl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ COMPOUND_ASSIGN(divides)
165165

166166
// ----------------- End of compound assignment ops ----------------
167167

168-
// Min and Max don't use ALIAS_STD because they don't exist in STD
168+
// Min, Max and Ternary Operator don't use ALIAS_STD because they don't exist in STD
169169
// TODO: implement as mix(rhs<lhs,lhs,rhs) (SPIR-V intrinsic from the extended set & glm on C++)
170170
template<typename T>
171171
struct minimum
@@ -195,6 +195,17 @@ struct maximum
195195
NBL_CONSTEXPR_STATIC_INLINE T identity = numeric_limits<scalar_t>::lowest; // TODO: `all_components<T>`
196196
};
197197

198+
template<typename T>
199+
struct ternary_operator
200+
{
201+
using type_t = T;
202+
203+
T operator()(bool condition, NBL_CONST_REF_ARG(T) lhs, NBL_CONST_REF_ARG(T) rhs)
204+
{
205+
return condition ? lhs : rhs;
206+
}
207+
};
208+
198209
}
199210
}
200211

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,8 @@ struct FFTIndexingUtils
200200
uint32_t mirrorLocalIndex;
201201
};
202202

203-
static NablaMirrorLocalInfo getNablaMirrorLocalInfo(uint32_t localElementIndex)
203+
static NablaMirrorLocalInfo getNablaMirrorLocalInfo(uint32_t globalElementIndex)
204204
{
205-
const uint32_t globalElementIndex = localElementIndex * WorkgroupSize | workgroup::SubgroupContiguousIndex();
206205
const uint32_t otherElementIndex = FFTIndexingUtils::getNablaMirrorIndex(globalElementIndex);
207206
const uint32_t mirrorLocalIndex = otherElementIndex / WorkgroupSize;
208207
const uint32_t otherThreadID = otherElementIndex & (WorkgroupSize - 1);

0 commit comments

Comments
 (0)