Skip to content

Commit ee33835

Browse files
committed
QOL update to FFT shuffles
1 parent 0a224ec commit ee33835

File tree

2 files changed

+7
-55
lines changed

2 files changed

+7
-55
lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 6 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,14 @@ namespace fft
2424

2525
// ---------------------------------- Utils -----------------------------------------------
2626
template<typename SharedMemoryAdaptor, typename Scalar>
27-
struct exchangeValues;
28-
29-
template<typename SharedMemoryAdaptor>
30-
struct exchangeValues<SharedMemoryAdaptor, float16_t>
27+
struct exchangeValues
3128
{
32-
static void __call(NBL_REF_ARG(complex_t<float16_t>) lo, NBL_REF_ARG(complex_t<float16_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
29+
static void __call(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
3330
{
3431
const bool topHalf = bool(threadID & stride);
35-
// Pack two halves into a single uint32_t
36-
uint32_t toExchange = bit_cast<uint32_t, float16_t2 >(topHalf ? float16_t2 (lo.real(), lo.imag()) : float16_t2 (hi.real(), hi.imag()));
37-
shuffleXor<SharedMemoryAdaptor, uint32_t>(toExchange, stride, sharedmemAdaptor);
38-
float16_t2 exchanged = bit_cast<float16_t2, uint32_t>(toExchange);
32+
// Pack into float vector because ternary operator does not support structs
33+
vector<Scalar, 2> exchanged = topHalf ? vector<Scalar, 2>(lo.real(), lo.imag()) : vector<Scalar, 2>(hi.real(), hi.imag());
34+
shuffleXor<SharedMemoryAdaptor, vector<Scalar, 2> >(exchanged, stride, sharedmemAdaptor);
3935
if (topHalf)
4036
{
4137
lo.real(exchanged.x);
@@ -45,51 +41,7 @@ struct exchangeValues<SharedMemoryAdaptor, float16_t>
4541
{
4642
hi.real(exchanged.x);
4743
lo.imag(exchanged.y);
48-
}
49-
}
50-
};
51-
52-
template<typename SharedMemoryAdaptor>
53-
struct exchangeValues<SharedMemoryAdaptor, float32_t>
54-
{
55-
static void __call(NBL_REF_ARG(complex_t<float32_t>) lo, NBL_REF_ARG(complex_t<float32_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
56-
{
57-
const bool topHalf = bool(threadID & stride);
58-
// pack into `float32_t2` because ternary operator doesn't support structs
59-
float32_t2 exchanged = topHalf ? float32_t2(lo.real(), lo.imag()) : float32_t2(hi.real(), hi.imag());
60-
shuffleXor<SharedMemoryAdaptor, float32_t2>(exchanged, stride, sharedmemAdaptor);
61-
if (topHalf)
62-
{
63-
lo.real(exchanged.x);
64-
lo.imag(exchanged.y);
6544
}
66-
else
67-
{
68-
hi.real(exchanged.x);
69-
hi.imag(exchanged.y);
70-
}
71-
}
72-
};
73-
74-
template<typename SharedMemoryAdaptor>
75-
struct exchangeValues<SharedMemoryAdaptor, float64_t>
76-
{
77-
static void __call(NBL_REF_ARG(complex_t<float64_t>) lo, NBL_REF_ARG(complex_t<float64_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
78-
{
79-
const bool topHalf = bool(threadID & stride);
80-
// pack into `float64_t2` because ternary operator doesn't support structs
81-
float64_t2 exchanged = topHalf ? float64_t2(lo.real(), lo.imag()) : float64_t2(hi.real(), hi.imag());
82-
shuffleXor<SharedMemoryAdaptor, float64_t2 >(exchanged, stride, sharedmemAdaptor);
83-
if (topHalf)
84-
{
85-
lo.real(exchanged.x);
86-
lo.imag(exchanged.y);
87-
}
88-
else
89-
{
90-
hi.real(exchanged.x);
91-
hi.imag(exchanged.y);
92-
}
9345
}
9446
};
9547

@@ -170,7 +122,7 @@ uint32_t getNegativeIndex(uint32_t idx)
170122

171123
// Util to unpack two values from the packed FFT X + iY - get outputs in the same input arguments, storing x to lo and y to hi
172124
template<typename Scalar>
173-
void unpack(NBL_CONST_REF_ARG(complex_t<Scalar>) lo, NBL_CONST_REF_ARG(complex_t<Scalar>) hi)
125+
void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
174126
{
175127
complex_t<Scalar> x = (lo + conj(hi)) * Scalar(0.5);
176128
hi = rotateRight<Scalar>(lo - conj(hi)) * 0.5;

0 commit comments

Comments
 (0)