@@ -11,28 +11,18 @@ namespace nbl
 {
 namespace hlsl
 {
-
-namespace glsl
-{
-
-// Define this method from glsl_compat/core.hlsl
-uint32_t3 gl_WorkGroupSize() {
-    return uint32_t3(_NBL_HLSL_WORKGROUP_SIZE_, 1, 1);
-}
-
-} //namespace glsl
-
 namespace workgroup
 {
-
+namespace fft
+{
 // ---------------------------------- Utils -----------------------------------------------

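 // Threads whose IDs differ only in the `stride` bit form a butterfly pair: `exchangeValues`
 // trades one complex value of the pair through shared memory so each thread ends up holding
 // the operand it needs for its half of the butterfly.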
 template<typename SharedMemoryAccessor, typename Scalar>
 void exchangeValues(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
 {
     const bool topHalf = bool(threadID & stride);
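     // The upper thread of each pair sends its `lo` and receives the partner's `hi`, and vice versa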
     vector<Scalar, 2> toExchange = topHalf ? vector<Scalar, 2>(lo.real(), lo.imag()) : vector<Scalar, 2>(hi.real(), hi.imag());
-    shuffleXor<SharedMemoryAccessor, Scalar, 2>(toExchange, stride, threadID, sharedmemAdaptor);
+    shuffleXor<SharedMemoryAccessor, vector<Scalar, 2> >::__call(toExchange, stride, sharedmemAdaptor);
     if (topHalf)
     {
         lo.real(toExchange.x);
@@ -67,60 +57,59 @@ struct FFT<2,false, Scalar, device_capabilities>
         exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAdaptor);

         // Get twiddle with k = threadID mod stride, halfN = stride
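         // (stride is a power of two, so threadID & (stride - 1) == threadID mod stride)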
-        fft::DIF<Scalar>::radix2(fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
+        hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
     }


     template<typename Accessor, typename SharedMemoryAccessor>
     static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
     {
         // Set up the MemAdaptors
-        MemoryAdaptor<Accessor, _NBL_HLSL_WORKGROUP_SIZE_ << 1> memAdaptor;
+        MemoryAdaptor<Accessor, 1> memAdaptor;
         memAdaptor.accessor = accessor;
         MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
         sharedmemAdaptor.accessor = sharedmemAccessor;

-        // Compute the SubgroupContiguousIndex only once
+        // Compute the indices only once
         const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
+        const uint32_t loIx = threadID;
+        const uint32_t hiIx = loIx + _NBL_HLSL_WORKGROUP_SIZE_;

         // Read lo, hi values from global memory
         vector<Scalar, 2> loVec;
         vector<Scalar, 2> hiVec;
-        memAdaptor.get(threadID, loVec);
-        memAdaptor.get(threadID + _NBL_HLSL_WORKGROUP_SIZE_, hiVec);
+        // TODO: if we get rid of the Memory Adaptor on the accessor and require complex getters and setters, then no `2*`
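+        // With the stride-1 adaptor, complex element i occupies scalars 2i (real) and 2i + 1 (imaginary), hence the `2*`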
+        memAdaptor.get(2 * loIx, loVec);
+        memAdaptor.get(2 * hiIx, hiVec);
         complex_t<Scalar> lo = {loVec.x, loVec.y};
         complex_t<Scalar> hi = {hiVec.x, hiVec.y};

-        // special first iteration - only if workgroupsize > subgroupsize
+        // If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
         if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
-            fft::DIF<Scalar>::radix2(fft::twiddle<false, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
+        {
+            // special first iteration
+            hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);

-        // Run bigger steps until Subgroup-sized
-        for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
-        {
-            // If at least one loop was executed, we must wait for all threads to get their values before we write to shared mem again
-            if (!(stride & (_NBL_HLSL_WORKGROUP_SIZE_ >> 1)))
+            // Run bigger steps until Subgroup-sized
+            for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
+            {
+                FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor);
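                 // Wait for all threads to be done reading shared memory before the next iteration overwrites it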
                 sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
-            FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor);
-        }
+            }

-        // special last workgroup-shuffle - only if workgroupsize > subgroupsize
-        if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
-        {
-            // Wait for all threads to be done with reads in the last loop before writing to shared mem
-            sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
-            exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
-        }
+            // special last workgroup-shuffle
+            exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+        }

         // Subgroup-sized FFT
-        subgroup::FFT<false, Scalar, device_capabilities>::__call(lo, hi);
+        subgroup::fft::FFT<false, Scalar, device_capabilities>::__call(lo, hi);

         // Put values back in global mem
         loVec = vector<Scalar, 2>(lo.real(), lo.imag());
         hiVec = vector<Scalar, 2>(hi.real(), hi.imag());

-        memAdaptor.set(threadID, loVec);
-        memAdaptor.set(threadID + _NBL_HLSL_WORKGROUP_SIZE_, hiVec);
+        memAdaptor.set(2 * loIx, loVec);
+        memAdaptor.set(2 * hiIx, hiVec);

         // Update state for accessors
         accessor = memAdaptor.accessor;
@@ -138,7 +127,7 @@ struct FFT<2,true, Scalar, device_capabilities>
     static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
     {
         // Get twiddle with k = threadID mod stride, halfN = stride
-        fft::DIF<Scalar>::radix2(fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);
+        hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);

         exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAdaptor);
     }
@@ -148,53 +137,54 @@ struct FFT<2,true, Scalar, device_capabilities>
     static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
     {
         // Set up the MemAdaptors
-        MemoryAdaptor<Accessor, _NBL_HLSL_WORKGROUP_SIZE_ << 1> memAdaptor;
+        MemoryAdaptor<Accessor, 1> memAdaptor;
         memAdaptor.accessor = accessor;
         MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
         sharedmemAdaptor.accessor = sharedmemAccessor;

-        // Compute the SubgroupContiguousIndex only once
+        // Compute the indices only once
         const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
+        const uint32_t loIx = (glsl::gl_SubgroupID() << (glsl::gl_SubgroupSizeLog2() + 1)) + glsl::gl_SubgroupInvocationID();
+        const uint32_t hiIx = loIx + glsl::gl_SubgroupSize();
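+        // Each subgroup owns a contiguous run of 2 * SubgroupSize complex elements: its `lo` values fill the lower half of the run and its `hi` values the upper half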

         // Read lo, hi values from global memory
         vector<Scalar, 2> loVec;
         vector<Scalar, 2> hiVec;
-        memAdaptor.get(threadID, loVec);
-        memAdaptor.get(threadID + _NBL_HLSL_WORKGROUP_SIZE_, hiVec);
+        memAdaptor.get(2 * loIx, loVec);
+        memAdaptor.get(2 * hiIx, hiVec);
         complex_t<Scalar> lo = {loVec.x, loVec.y};
         complex_t<Scalar> hi = {hiVec.x, hiVec.y};

         // Run a subgroup-sized FFT, then continue with bigger steps
-        subgroup::FFT<true, Scalar, device_capabilities>::__call(lo, hi);
+        subgroup::fft::FFT<true, Scalar, device_capabilities>::__call(lo, hi);

-        // special first workgroup-shuffle - only if workgroupsize > subgroupsize
+        // If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
+
         if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
         {
+            // special first workgroup-shuffle
             exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
-        }
-
-        // The bigger steps
-        for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
-        {
-            // If we enter this for loop, then the special first workgroup shuffle went through, so wait on that
-            sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
-            FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor);
-        }
+
+            // The bigger steps
+            for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
+            {
+                // Order of waiting for shared mem writes is also reversed here, since the shuffle came earlier
+                sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
+                FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor);
+            }

-        // special last iteration - only if workgroupsize > subgroupsize
-        if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
-        {
-            fft::DIT<Scalar>::radix2(fft::twiddle<true, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
+            // special last iteration
+            hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
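             // Presumably the subgroup inverse FFT only normalizes by its own length, so the workgroup-level steps still owe the remaining factor of WorkgroupSize / SubgroupSize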
             divides_assign< complex_t<Scalar> > divAss;
             divAss(lo, _NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize());
-            divAss(hi, _NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize());
-        }
+            divAss(hi, _NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize());
+        }

         // Put values back in global mem
         loVec = vector<Scalar, 2>(lo.real(), lo.imag());
         hiVec = vector<Scalar, 2>(hi.real(), hi.imag());
-        memAdaptor.set(threadID, loVec);
-        memAdaptor.set(threadID + _NBL_HLSL_WORKGROUP_SIZE_, hiVec);
+        memAdaptor.set(2 * loIx, loVec);
+        memAdaptor.set(2 * hiIx, hiVec);

         // Update state for accessors
         accessor = memAdaptor.accessor;
@@ -203,21 +193,6 @@ struct FFT<2,true, Scalar, device_capabilities>
 };


-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 // ---------------------------- Below pending --------------------------------------------------

 /*
@@ -246,5 +221,6 @@ struct FFT
 }
 }
 }
+}

 #endif