4
4
#include "nbl/builtin/hlsl/subgroup/fft.hlsl"
5
5
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
6
6
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
7
- #include "nbl/builtin/hlsl/memory_accessor.hlsl"
8
7
#include "nbl/builtin/hlsl/workgroup/shuffle.hlsl"
9
8
10
9
namespace nbl
@@ -18,43 +17,41 @@ namespace fft
18
17
// ---------------------------------- Utils -----------------------------------------------
19
18
20
19
template<typename SharedMemoryAccessor, typename Scalar>
21
- void exchangeValues (NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (MemoryAdaptor< SharedMemoryAccessor>) sharedmemAdaptor )
20
+ void exchangeValues (NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor )
22
21
{
23
22
const bool topHalf = bool (threadID & stride);
23
+ // Ternary won't take structs so we use this aux variable
24
24
vector <Scalar, 2 > toExchange = topHalf ? vector <Scalar, 2 >(lo.real (), lo.imag ()) : vector <Scalar, 2 >(hi.real (), hi.imag ());
25
- shuffleXor<SharedMemoryAccessor, vector <Scalar, 2 > >::__call (toExchange, stride, sharedmemAdaptor);
25
+ complex_t<Scalar> toExchangeComplex = {toExchange.x, toExchange.y};
26
+ shuffleXor<SharedMemoryAccessor, complex_t<Scalar> >::__call (toExchangeComplex, stride, sharedmemAccessor);
26
27
if (topHalf)
27
- {
28
- lo.real (toExchange.x);
29
- lo.imag (toExchange.y);
30
- }
28
+ lo = toExchangeComplex;
31
29
else
32
- {
33
- hi.real (toExchange.x);
34
- hi.imag (toExchange.y);
35
- }
30
+ hi = toExchangeComplex;
36
31
}
37
32
33
+ } //namespace fft
34
+
38
35
// ----------------------------------- End Utils -----------------------------------------------
39
36
40
37
template<uint16_t ElementsPerInvocation, bool Inverse, typename Scalar, class device_capabilities=void >
41
38
struct FFT;
42
39
43
40
// For the FFT methods below, we assume:
44
41
// - Accessor is a global memory accessor to an array fitting 2 * _NBL_HLSL_WORKGROUP_SIZE_ elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
45
- // that is, one "lo" and one "hi" complex numbers per thread, essentially 4 Scalars per thread. The data layout is assumed to be a whole array of real parts
46
- // followed by a whole array of imaginary parts. So it would be something like
47
- // [x_0, x_1, ..., x_{2 * _NBL_HLSL_WORKGROUP_SIZE_}, y_0, y_1, ..., y_{2 * _NBL_HLSL_WORKGROUP_SIZE_}]
48
- // - SharedMemoryAccessor accesses a shared memory array that can fit _NBL_HLSL_WORKGROUP_SIZE_ elements of type complex_t<Scalar>, so 2 * _NBL_HLSL_WORKGROUP_SIZE_ Scalars
42
+ // that is, one "lo" and one "hi" complex numbers per thread, essentially 4 Scalars per thread.
43
+ // There are no assumptions on the data layout: we just require the accessor to provide get and set methods for complex_t<Scalar>.
44
+ // - SharedMemoryAccessor accesses a shared memory array that can fit _NBL_HLSL_WORKGROUP_SIZE_ elements of type complex_t<Scalar>, with get and set
45
+ // methods for complex_t<Scalar>. It benefits from coalesced accesses
49
46
50
47
// 2 items per invocation forward specialization
51
48
template<typename Scalar, class device_capabilities>
52
49
struct FFT<2 ,false , Scalar, device_capabilities>
53
50
{
54
51
template<typename SharedMemoryAccessor>
55
- static void FFT_loop (uint32_t stride, NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG (MemoryAdaptor< SharedMemoryAccessor>) sharedmemAdaptor )
52
+ static void FFT_loop (uint32_t stride, NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor )
56
53
{
57
- exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAdaptor );
54
+ fft:: exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAccessor );
58
55
59
56
// Get twiddle with k = threadID mod stride, halfN = stride
60
57
hlsl::fft::DIF<Scalar>::radix2 (hlsl::fft::twiddle<false , Scalar>(threadID & (stride - 1 ), stride), lo, hi);
@@ -64,25 +61,14 @@ struct FFT<2,false, Scalar, device_capabilities>
64
61
template<typename Accessor, typename SharedMemoryAccessor>
65
62
static void __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
66
63
{
67
- // Set up the MemAdaptors
68
- MemoryAdaptor<Accessor, 1 > memAdaptor;
69
- memAdaptor.accessor = accessor;
70
- MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
71
- sharedmemAdaptor.accessor = sharedmemAccessor;
72
-
73
64
// Compute the indices only once
74
65
const uint32_t threadID = uint32_t (SubgroupContiguousIndex ());
75
66
const uint32_t loIx = threadID;
76
67
const uint32_t hiIx = loIx + _NBL_HLSL_WORKGROUP_SIZE_;
77
68
78
69
// Read lo, hi values from global memory
79
- vector <Scalar, 2 > loVec;
80
- vector <Scalar, 2 > hiVec;
81
- // TODO: if we get rid of the Memory Adaptor on the accessor and require comples getters and setters, then no `2*`
82
- memAdaptor.get (2 * loIx , loVec);
83
- memAdaptor.get (2 * hiIx, hiVec);
84
- complex_t<Scalar> lo = {loVec.x, loVec.y};
85
- complex_t<Scalar> hi = {hiVec.x, hiVec.y};
70
+ complex_t<Scalar> lo = accessor.get (loIx);
71
+ complex_t<Scalar> hi = accessor.get (hiIx);
86
72
87
73
// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
88
74
if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize ())
@@ -93,27 +79,20 @@ struct FFT<2,false, Scalar, device_capabilities>
93
79
// Run bigger steps until Subgroup-sized
94
80
for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1 ; stride > glsl::gl_SubgroupSize (); stride >>= 1 )
95
81
{
96
- FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor );
97
- sharedmemAdaptor .workgroupExecutionAndMemoryBarrier ();
82
+ FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAccessor );
83
+ sharedmemAccessor .workgroupExecutionAndMemoryBarrier ();
98
84
}
99
85
100
86
// special last workgroup-shuffle
101
- exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize (), sharedmemAdaptor );
87
+ fft:: exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize (), sharedmemAccessor );
102
88
}
103
89
104
90
// Subgroup-sized FFT
105
- subgroup::fft:: FFT<false , Scalar, device_capabilities>::__call (lo, hi);
91
+ subgroup::FFT<false , Scalar, device_capabilities>::__call (lo, hi);
106
92
107
93
// Put values back in global mem
108
- loVec = vector <Scalar, 2 >(lo.real (), lo.imag ());
109
- hiVec = vector <Scalar, 2 >(hi.real (), hi.imag ());
110
-
111
- memAdaptor.set (2 * loIx, loVec);
112
- memAdaptor.set (2 * hiIx, hiVec);
113
-
114
- // Update state for accessors
115
- accessor = memAdaptor.accessor;
116
- sharedmemAccessor = sharedmemAdaptor.accessor;
94
+ accessor.set (loIx, lo);
95
+ accessor.set (hiIx, hi);
117
96
}
118
97
};
119
98
@@ -124,53 +103,42 @@ template<typename Scalar, class device_capabilities>
124
103
struct FFT<2 ,true , Scalar, device_capabilities>
125
104
{
126
105
template<typename SharedMemoryAccessor>
127
- static void FFT_loop (uint32_t stride, NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG (MemoryAdaptor< SharedMemoryAccessor>) sharedmemAdaptor )
106
+ static void FFT_loop (uint32_t stride, NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor )
128
107
{
129
108
// Get twiddle with k = threadID mod stride, halfN = stride
130
109
hlsl::fft::DIT<Scalar>::radix2 (hlsl::fft::twiddle<true , Scalar>(threadID & (stride - 1 ), stride), lo, hi);
131
110
132
- exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAdaptor );
111
+ fft:: exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAccessor );
133
112
}
134
113
135
114
136
115
template<typename Accessor, typename SharedMemoryAccessor>
137
116
static void __call (NBL_REF_ARG (Accessor) accessor, NBL_REF_ARG (SharedMemoryAccessor) sharedmemAccessor)
138
117
{
139
- // Set up the MemAdaptors
140
- MemoryAdaptor<Accessor, 1 > memAdaptor;
141
- memAdaptor.accessor = accessor;
142
- MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
143
- sharedmemAdaptor.accessor = sharedmemAccessor;
144
-
145
118
// Compute the indices only once
146
119
const uint32_t threadID = uint32_t (SubgroupContiguousIndex ());
147
120
const uint32_t loIx = (glsl::gl_SubgroupID ()<<(glsl::gl_SubgroupSizeLog2 ()+1 ))+glsl::gl_SubgroupInvocationID ();
148
121
const uint32_t hiIx = loIx+glsl::gl_SubgroupSize ();
149
122
150
123
// Read lo, hi values from global memory
151
- vector <Scalar, 2 > loVec;
152
- vector <Scalar, 2 > hiVec;
153
- memAdaptor.get (2 * loIx , loVec);
154
- memAdaptor.get (2 * hiIx, hiVec);
155
- complex_t<Scalar> lo = {loVec.x, loVec.y};
156
- complex_t<Scalar> hi = {hiVec.x, hiVec.y};
124
+ complex_t<Scalar> lo = accessor.get (loIx);
125
+ complex_t<Scalar> hi = accessor.get (hiIx);
157
126
158
127
// Run a subgroup-sized FFT, then continue with bigger steps
159
- subgroup::fft:: FFT<true , Scalar, device_capabilities>::__call (lo, hi);
128
+ subgroup::FFT<true , Scalar, device_capabilities>::__call (lo, hi);
160
129
161
130
// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
162
-
163
131
if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize ())
164
132
{
165
133
// special first workgroup-shuffle
166
- exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize (), sharedmemAdaptor );
134
+ fft:: exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize (), sharedmemAccessor );
167
135
168
136
// The bigger steps
169
137
for (uint32_t stride = glsl::gl_SubgroupSize () << 1 ; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1 )
170
138
{
171
139
// Order of waiting for shared mem writes is also reversed here, since the shuffle came earlier
172
- sharedmemAdaptor .workgroupExecutionAndMemoryBarrier ();
173
- FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor );
140
+ sharedmemAccessor .workgroupExecutionAndMemoryBarrier ();
141
+ FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAccessor );
174
142
}
175
143
176
144
// special last iteration
@@ -181,14 +149,8 @@ struct FFT<2,true, Scalar, device_capabilities>
181
149
}
182
150
183
151
// Put values back in global mem
184
- loVec = vector <Scalar, 2 >(lo.real (), lo.imag ());
185
- hiVec = vector <Scalar, 2 >(hi.real (), hi.imag ());
186
- memAdaptor.set (2 * loIx, loVec);
187
- memAdaptor.set (2 * hiIx, hiVec);
188
-
189
- // Update state for accessors
190
- accessor = memAdaptor.accessor;
191
- sharedmemAccessor = sharedmemAdaptor.accessor;
152
+ accessor.set (loIx, lo);
153
+ accessor.set (hiIx, hi);
192
154
}
193
155
};
194
156
@@ -221,6 +183,5 @@ struct FFT
221
183
}
222
184
}
223
185
}
224
- }
225
186
226
187
#endif
0 commit comments