
Commit ee4474f

Changed the memory accessor pattern for the FFT to hide coalescing from the user. Fixed bit_cast and a bunch of warnings emitted when using half floats.
1 parent 3eb1826 commit ee4474f

File tree

5 files changed: +145 -43 lines

include/nbl/builtin/hlsl/bit.hlsl
include/nbl/builtin/hlsl/complex.hlsl
include/nbl/builtin/hlsl/subgroup/fft.hlsl
include/nbl/builtin/hlsl/workgroup/fft.hlsl
include/nbl/builtin/hlsl/workgroup/shuffle.hlsl


include/nbl/builtin/hlsl/bit.hlsl

Lines changed: 1 addition & 2 deletions
@@ -34,9 +34,8 @@ namespace hlsl
 {

 template<class T, class U>
-T bit_cast(U val)
+enable_if_t<sizeof(T) <= sizeof(U), T> bit_cast(U val)
 {
-    static_assert(sizeof(T) <= sizeof(U));
     return spirv::bitcast<T, U>(val);
 }
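A note on what the new signature buys: the size check moves from a static_assert inside the body into SFINAE on the return type, so an oversized destination type makes the overload disappear from overload resolution instead of erroring mid-instantiation. A minimal usage sketch (hypothetical values; the half-pair packing mirrors what workgroup/fft.hlsl does below):

// OK: both types are 4 bytes, so sizeof(T) <= sizeof(U) holds
vector <float16_t, 2> halves = vector <float16_t, 2>(float16_t(1.0), float16_t(-1.0));
uint32_t packed = bit_cast<uint32_t, vector <float16_t, 2> >(halves);

// No longer even a viable overload, since sizeof(float64_t) > sizeof(uint32_t):
// float64_t wrong = bit_cast<float64_t, uint32_t>(packed);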

include/nbl/builtin/hlsl/complex.hlsl

Lines changed: 8 additions & 0 deletions
@@ -191,6 +191,10 @@ const static complex_t< SCALAR > multiplies< complex_t< SCALAR > >::identity = {
 template<> \
 const static complex_t< SCALAR > divides< complex_t< SCALAR > >::identity = { promote< SCALAR , uint32_t>(1), promote< SCALAR , uint32_t>(0)};

+COMPLEX_ARITHMETIC_IDENTITIES(float16_t)
+COMPLEX_ARITHMETIC_IDENTITIES(float16_t2)
+COMPLEX_ARITHMETIC_IDENTITIES(float16_t3)
+COMPLEX_ARITHMETIC_IDENTITIES(float16_t4)
 COMPLEX_ARITHMETIC_IDENTITIES(float32_t)
 COMPLEX_ARITHMETIC_IDENTITIES(float32_t2)
 COMPLEX_ARITHMETIC_IDENTITIES(float32_t3)

@@ -287,6 +291,10 @@ COMPLEX_COMPOUND_ASSIGN_IDENTITY(minus, SCALAR) \
 COMPLEX_COMPOUND_ASSIGN_IDENTITY(multiplies, SCALAR) \
 COMPLEX_COMPOUND_ASSIGN_IDENTITY(divides, SCALAR)

+COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float16_t)
+COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float16_t2)
+COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float16_t3)
+COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float16_t4)
 COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float32_t)
 COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float32_t2)
 COMPLEX_COMPOUND_ASSIGN_IDENTITIES(float32_t3)
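These additions only instantiate the existing identity constants for half-precision complex types. Going by the macro body visible in the context lines above, the divides case for float16_t should expand to roughly the following (plus/minus/multiplies are analogous, and the float16_t2/3/4 variants substitute the vector types):

template<>
const static complex_t<float16_t> divides< complex_t<float16_t> >::identity = { promote<float16_t, uint32_t>(1), promote<float16_t, uint32_t>(0) };

That is, the identity 1 + 0i, with promote<> producing the right scalar or vector value for each instantiation.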

include/nbl/builtin/hlsl/subgroup/fft.hlsl

Lines changed: 2 additions & 2 deletions
@@ -94,8 +94,8 @@ struct FFT<true, Scalar, device_capabilities>
         // special last iteration
         fft::DIT<Scalar>::radix2(fft::twiddle<true, Scalar>(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi);
         divides_assign< complex_t<Scalar> > divAss;
-        divAss(lo, doubleSubgroupSize);
-        divAss(hi, doubleSubgroupSize);
+        divAss(lo, Scalar(doubleSubgroupSize));
+        divAss(hi, Scalar(doubleSubgroupSize));
     }
 };
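The only change here is wrapping the divisor in Scalar(...): doubleSubgroupSize is an integer, and when Scalar is float16_t the implicit integer-to-half conversion at the call site is among the warnings the commit message mentions, so the explicit constructor makes the narrowing intentional. What each call amounts to, assuming the natural in-place semantics of divides_assign:

// lo.real(lo.real() / Scalar(doubleSubgroupSize));
// lo.imag(lo.imag() / Scalar(doubleSubgroupSize));
// i.e. the usual 1/N normalization of an inverse FFT, with N = 2 * subgroupSize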

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 107 additions & 32 deletions
@@ -7,6 +7,10 @@
 #include "nbl/builtin/hlsl/workgroup/shuffle.hlsl"
 #include "nbl/builtin/hlsl/mpl.hlsl"
 #include "nbl/builtin/hlsl/memory_accessor.hlsl"
+#include "nbl/builtin/hlsl/bit.hlsl"
+
+// Caveats
+// - Sin and Cos in HLSL take 32-bit floats. Using this library with 64-bit floats works perfectly fine, but DXC will emit warnings

 namespace nbl
 {
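To make the caveat concrete, a hypothetical snippet (not part of the diff) showing why 64-bit scalars compile but warn: cos() is only declared for 32-bit floats, so the angle is converted implicitly and the math runs at fp32 precision.

// Hypothetical illustration of the caveat above
float64_t angle = 0.5;
float64_t c = cos(angle); // DXC warning: implicit float64_t -> float32_t conversion; computed at fp32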
@@ -18,20 +22,77 @@ namespace fft
 {

 // ---------------------------------- Utils -----------------------------------------------
+template<typename SharedMemoryAdaptor, typename Scalar>
+struct exchangeValues;

-template<typename SharedMemoryAccessor, typename Scalar>
-void exchangeValues(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
+template<typename SharedMemoryAdaptor>
+struct exchangeValues<SharedMemoryAdaptor, float16_t>
 {
-    const bool topHalf = bool(threadID & stride);
-    // Ternary won't take structs so we use this aux variable
-    vector <Scalar, 2> toExchange = topHalf ? vector <Scalar, 2>(lo.real(), lo.imag()) : vector <Scalar, 2>(hi.real(), hi.imag());
-    complex_t<Scalar> toExchangeComplex = {toExchange.x, toExchange.y};
-    shuffleXor<SharedMemoryAccessor, complex_t<Scalar> >::__call(toExchangeComplex, stride, sharedmemAccessor);
-    if (topHalf)
-        lo = toExchangeComplex;
-    else
-        hi = toExchangeComplex;
-}
+    static void __call(NBL_REF_ARG(complex_t<float16_t>) lo, NBL_REF_ARG(complex_t<float16_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    {
+        const bool topHalf = bool(threadID & stride);
+        // Ternary won't take structs so we use this aux variable
+        uint32_t toExchange = bit_cast<uint32_t, vector <float16_t, 2> >(topHalf ? vector <float16_t, 2>(lo.real(), lo.imag()) : vector <float16_t, 2>(hi.real(), hi.imag()));
+        shuffleXor<SharedMemoryAdaptor, uint32_t>::__call(toExchange, stride, sharedmemAdaptor);
+        vector <float16_t, 2> exchanged = bit_cast<vector <float16_t, 2>, uint32_t>(toExchange);
+        if (topHalf)
+        {
+            lo.real(exchanged.x);
+            lo.imag(exchanged.y);
+        }
+        else
+        {
+            hi.real(exchanged.x);
+            hi.imag(exchanged.y);
+        }
+    }
+};
+
+template<typename SharedMemoryAdaptor>
+struct exchangeValues<SharedMemoryAdaptor, float32_t>
+{
+    static void __call(NBL_REF_ARG(complex_t<float32_t>) lo, NBL_REF_ARG(complex_t<float32_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    {
+        const bool topHalf = bool(threadID & stride);
+        // Ternary won't take structs so we use this aux variable
+        vector <uint32_t, 2> toExchange = bit_cast<vector <uint32_t, 2>, vector <float32_t, 2> >(topHalf ? vector <float32_t, 2>(lo.real(), lo.imag()) : vector <float32_t, 2>(hi.real(), hi.imag()));
+        shuffleXor<SharedMemoryAdaptor, vector <uint32_t, 2> >::__call(toExchange, stride, sharedmemAdaptor);
+        vector <float32_t, 2> exchanged = bit_cast<vector <float32_t, 2>, vector <uint32_t, 2> >(toExchange);
+        if (topHalf)
+        {
+            lo.real(exchanged.x);
+            lo.imag(exchanged.y);
+        }
+        else
+        {
+            hi.real(exchanged.x);
+            hi.imag(exchanged.y);
+        }
+    }
+};
+
+template<typename SharedMemoryAdaptor>
+struct exchangeValues<SharedMemoryAdaptor, float64_t>
+{
+    static void __call(NBL_REF_ARG(complex_t<float64_t>) lo, NBL_REF_ARG(complex_t<float64_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    {
+        const bool topHalf = bool(threadID & stride);
+        // Ternary won't take structs so we use this aux variable
+        vector <uint32_t, 4> toExchange = bit_cast<vector <uint32_t, 4>, vector <float64_t, 2> >(topHalf ? vector <float64_t, 2>(lo.real(), lo.imag()) : vector <float64_t, 2>(hi.real(), hi.imag()));
+        shuffleXor<SharedMemoryAdaptor, vector <uint32_t, 4> >::__call(toExchange, stride, sharedmemAdaptor);
+        vector <float64_t, 2> exchanged = bit_cast<vector <float64_t, 2>, vector <uint32_t, 4> >(toExchange);
+        if (topHalf)
+        {
+            lo.real(exchanged.x);
+            lo.imag(exchanged.y);
+        }
+        else
+        {
+            hi.real(exchanged.x);
+            hi.imag(exchanged.y);
+        }
+    }
+};

 } //namespace fft
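Why three hand-written specializations instead of the old generic function: a complex pair of each width packs exactly into 32-bit words (one uint32_t for two halves, a uint2 for two floats, a uint4 for two doubles), so every workgroup shuffle now moves plain uint vectors through shared memory. That is what lets the MemoryAdaptor own the layout and hide coalescing from the user, as the commit message says; presumably the backing shared array can then be declared once in terms of uints regardless of Scalar. A hedged call-site sketch, with names matching the FFT bodies below:

// One butterfly exchange at a given stride. Each thread pairs with the
// thread at threadID ^ stride; threads whose stride bit is set trade away
// their lo, the others trade away their hi.
fft::exchangeValues<MemoryAdaptor<SharedMemoryAccessor>, float32_t>::__call(lo, hi, threadID, stride, sharedmemAdaptor);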

@@ -51,10 +112,10 @@ struct FFT;
 template<typename Scalar, class device_capabilities>
 struct FFT<2,false, Scalar, device_capabilities>
 {
-    template<typename SharedMemoryAccessor>
-    static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
+    template<typename SharedMemoryAdaptor>
+    static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
     {
-        fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAccessor);
+        fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);

         // Get twiddle with k = threadID mod stride, halfN = stride
         hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
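One detail in the unchanged context line worth spelling out: the AND is a cheap modulus, valid because stride is always a power of two in this loop.

// threadID & (stride - 1) == threadID % stride for power-of-two stride
// e.g. stride = 8, threadID = 13: 13 & 7 = 5 = 13 % 8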
@@ -77,18 +138,25 @@ struct FFT<2,false, Scalar, device_capabilities>
         // If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
         if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
         {
+            // Set up the memory adaptor
+            MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
+            sharedmemAdaptor.accessor = sharedmemAccessor;
+
             // special first iteration
             hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);

             // Run bigger steps until Subgroup-sized
             for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
             {
-                FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAccessor);
-                sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
+                FFT_loop< MemoryAdaptor<SharedMemoryAccessor> >(stride, lo, hi, threadID, sharedmemAdaptor);
+                sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
             }

             // special last workgroup-shuffle
-            fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAccessor);
+            fft::exchangeValues<MemoryAdaptor<SharedMemoryAccessor>, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+
+            // Remember to update the accessor's state
+            sharedmemAccessor = sharedmemAdaptor.accessor;
         }

         // Subgroup-sized FFT
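This setup/copy-back pair is the heart of the new pattern: the FFT wraps the caller's plain SharedMemoryAccessor in a MemoryAdaptor internally, runs all shuffles through the adaptor, then writes the accessor back out. The copy-back is needed because the adaptor holds the accessor by value (HLSL structs cannot hold references), so any state the accessor carries would otherwise be lost when the adaptor goes out of scope. A hypothetical caller sketch, with accessor type names invented for illustration and the __call signature assumed to match the K > 2 specialization shown further down:

// Only plain accessors cross the API boundary; the adaptor never leaks out.
MyGlobalAccessor accessor;
MySharedMemAccessor sharedmem;
FFT<2, false, float32_t, device_capabilities>::__call(accessor, sharedmem);
// sharedmem now reflects whatever state the internal adaptor mutated.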
@@ -106,13 +174,13 @@ struct FFT<2,false, Scalar, device_capabilities>
 template<typename Scalar, class device_capabilities>
 struct FFT<2,true, Scalar, device_capabilities>
 {
-    template<typename SharedMemoryAccessor>
-    static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
+    template<typename SharedMemoryAdaptor>
+    static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
     {
         // Get twiddle with k = threadID mod stride, halfN = stride
         hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);

-        fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAccessor);
+        fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
     }
@@ -135,22 +203,29 @@ struct FFT<2,true, Scalar, device_capabilities>
         // If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
         if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
         {
+            // Set up the memory adaptor
+            MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
+            sharedmemAdaptor.accessor = sharedmemAccessor;
+
             // special first workgroup-shuffle
-            fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAccessor);
+            fft::exchangeValues<MemoryAdaptor<SharedMemoryAccessor>, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);

             // The bigger steps
             for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
             {
                 // Order of waiting for shared mem writes is also reversed here, since the shuffle came earlier
-                sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
-                FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAccessor);
+                sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
+                FFT_loop< MemoryAdaptor<SharedMemoryAccessor> >(stride, lo, hi, threadID, sharedmemAdaptor);
             }

             // special last iteration
             hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
             divides_assign< complex_t<Scalar> > divAss;
-            divAss(lo, _NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize());
-            divAss(hi, _NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize());
+            divAss(lo, Scalar(_NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize()));
+            divAss(hi, Scalar(_NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize()));
+
+            // Remember to update the accessor's state
+            sharedmemAccessor = sharedmemAdaptor.accessor;
         }

         // Put values back in global mem
@@ -166,10 +241,10 @@ struct FFT<K, false, Scalar, device_capabilities>
     template<typename Accessor, typename SharedMemoryAccessor>
     static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
     {
-        for (uint32_t stride = (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; stride > _NBL_HLSL_WORKGROUP_SIZE_; stride >>= 1)
+        for (uint32_t stride = (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; stride > _NBL_HLSL_WORKGROUP_SIZE_; stride >>= 1)
         {
             //[unroll(K/2)]
-            for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
+            for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
             {
                 const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
                 const uint32_t hiIx = loIx | stride;
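The loIx/hiIx lines in the context are the standard butterfly indexing: split virtualThreadID into its high bits (shifted up by one) and its low stride-sized bits, and the partner element then sits exactly stride away. A worked example:

// stride = 4, virtualThreadID = 5:
// loIx = ((5 & ~3) << 1) | (5 & 3) = (4 << 1) | 1 = 9
// hiIx = 9 | 4 = 13
// so this virtual thread owns elements 9 and 13, exactly stride apart.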
@@ -223,7 +298,7 @@ struct FFT<K, true, Scalar, device_capabilities>
     {
         accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
         //[unroll(K/2)]
-        for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
+        for (uint32_t virtualThreadID = SubgroupContiguousIndex(); virtualThreadID < (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_; virtualThreadID += _NBL_HLSL_WORKGROUP_SIZE_)
         {
             const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
             const uint32_t hiIx = loIx | stride;
@@ -235,11 +310,11 @@ struct FFT<K, true, Scalar, device_capabilities>
             hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride), lo,hi);

             // Divide by special factor at the end
-            if ( (K >> 1) * _NBL_HLSL_WORKGROUP_SIZE_ == stride)
+            if ( (K / 2) * _NBL_HLSL_WORKGROUP_SIZE_ == stride)
             {
                 divides_assign< complex_t<Scalar> > divAss;
-                divAss(lo, K >> 1);
-                divAss(hi, K >> 1);
+                divAss(lo, K / 2);
+                divAss(hi, K / 2);
             }

             accessor.set(loIx, lo);

include/nbl/builtin/hlsl/workgroup/shuffle.hlsl

Lines changed: 27 additions & 7 deletions
@@ -14,22 +14,42 @@ namespace hlsl
 namespace workgroup
 {

-template<typename SharedMemoryAccessor, typename T>
+template<typename SharedMemoryAdaptor, typename T>
 struct shuffleXor
 {
-    static void __call(NBL_REF_ARG(T) value, uint32_t mask, uint32_t threadID, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
+    static void __call(NBL_REF_ARG(T) value, uint32_t mask, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
     {
-        sharedmemAccessor.set(threadID, value);
+        sharedmemAdaptor.template set<T>(threadID, value);

         // Wait until all writes are done before reading
-        sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
+        sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();

-        sharedmemAccessor.get(threadID ^ mask, value);
+        sharedmemAdaptor.template get<T>(threadID ^ mask, value);
     }

-    static void __call(NBL_REF_ARG(T) value, uint32_t mask, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
+    static void __call(NBL_REF_ARG(T) value, uint32_t mask, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
     {
-        __call(value, mask, uint32_t(SubgroupContiguousIndex()), sharedmemAccessor);
+        __call(value, mask, uint32_t(SubgroupContiguousIndex()), sharedmemAdaptor);
+    }
+};
+
+// Vector specialization
+template<typename SharedMemoryAdaptor, typename T, uint32_t N>
+struct shuffleXor<SharedMemoryAdaptor, vector <T, N> >
+{
+    static enable_if_t<N <= 4> __call(NBL_REF_ARG(vector <T, N>) value, uint32_t mask, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    {
+        sharedmemAdaptor.template set<T>(threadID, value);
+
+        // Wait until all writes are done before reading
+        sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
+
+        sharedmemAdaptor.template get<T>(threadID ^ mask, value);
+    }
+
+    static enable_if_t<N <= 4> __call(NBL_REF_ARG(vector <T, N>) value, uint32_t mask, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    {
+        __call(value, mask, uint32_t(SubgroupContiguousIndex()), sharedmemAdaptor);
     }
 };
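The vector specialization exists because the adaptor's set/get are templated on the element type: the generic overload hands the whole T to a single call, while the vector form passes the scalar T so the adaptor can place the N lanes itself, which is presumably where the coalescing mentioned in the commit message gets handled. Usage mirrors the fft::exchangeValues specializations above, e.g. for a float32_t complex packed as a uint pair (names taken from that code):

shuffleXor<SharedMemoryAdaptor, vector <uint32_t, 2> >::__call(toExchange, stride, sharedmemAdaptor);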
