Skip to content

Commit 51bdd2b

Browse files
committed
Checkpoint: Workgroup FFT functional!
1 parent: f85372c; commit: 51bdd2b

File tree

1 file changed

+27 -19 lines changed
  • include/nbl/builtin/hlsl/workgroup

1 file changed

+27 -19 lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 27 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ struct FFT<2,false, Scalar, device_capabilities>
6666
// Compute the indices only once
6767
const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
6868
const uint32_t loIx = threadID;
69-
const uint32_t hiIx = loIx + _NBL_HLSL_WORKGROUP_SIZE_;
69+
const uint32_t hiIx = _NBL_HLSL_WORKGROUP_SIZE_ | loIx;
7070

7171
// Read lo, hi values from global memory
7272
complex_t<Scalar> lo = accessor.get(loIx);
@@ -119,8 +119,8 @@ struct FFT<2,true, Scalar, device_capabilities>
119119
{
120120
// Compute the indices only once
121121
const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
122-
const uint32_t loIx = (glsl::gl_SubgroupID()<<(glsl::gl_SubgroupSizeLog2()+1))+glsl::gl_SubgroupInvocationID();
123-
const uint32_t hiIx = loIx+glsl::gl_SubgroupSize();
122+
const uint32_t loIx = threadID;
123+
const uint32_t hiIx = _NBL_HLSL_WORKGROUP_SIZE_ | loIx;
124124

125125
// Read lo, hi values from global memory
126126
complex_t<Scalar> lo = accessor.get(loIx);
@@ -175,19 +175,19 @@ struct FFT<K, false, Scalar, device_capabilities>
175175
//[unroll(K/2)]
176176
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
177177
{
178-
const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) + SubgroupContiguousIndex();
178+
const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex();
179179

180180
const uint32_t lsb = virtualThread & (stride - 1);
181181
const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
182182
const uint32_t hiIx = loIx | stride;
183183

184-
complex_t<Scalar> lo = accessor.get(loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
185-
complex_t<Scalar> hi = accessor.get(hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
184+
complex_t<Scalar> lo = accessor.get((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
185+
complex_t<Scalar> hi = accessor.get((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
186186

187187
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false,Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_),lo,hi);
188188

189-
accessor.set(loIx, lo);
190-
accessor.set(hiIx, hi);
189+
accessor.set((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), lo);
190+
accessor.set((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), hi);
191191
}
192192
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
193193
stride >>= 1;
@@ -225,33 +225,41 @@ struct FFT<K, true, Scalar, device_capabilities>
225225
FFT<2,true, Scalar, device_capabilities>::template __call(offsetAccessor,sharedmemAccessor);
226226
}
227227
accessor = offsetAccessor.accessor;
228-
/*
228+
229229
static const uint32_t virtualThreadCount = K >> 1;
230230
static const uint16_t passes = mpl::log2<K>::value - 1;
231-
uint32_t stride = K << 1;
231+
uint32_t stride = 2;
232232
//[unroll(passes)]
233233
for (uint16_t pass = 0; pass < passes; pass++)
234234
{
235235
//[unroll(K/2)]
236236
for (uint32_t virtualThread = 0; virtualThread < virtualThreadCount; virtualThread++)
237237
{
238-
const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) + SubgroupContiguousIndex();
238+
const uint32_t virtualThreadID = (virtualThread << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex();
239239

240240
const uint32_t lsb = virtualThread & (stride - 1);
241241
const uint32_t loIx = ((virtualThread ^ lsb) << 1) | lsb;
242242
const uint32_t hiIx = loIx | stride;
243+
244+
complex_t<Scalar> lo = accessor.get((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
245+
complex_t<Scalar> hi = accessor.get((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex());
246+
247+
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_),lo,hi);
243248

244-
complex_t<Scalar> lo = accessor.get(loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
245-
complex_t<Scalar> hi = accessor.get(hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_);
246-
247-
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & ((stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) - 1), stride << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_),lo,hi);
248-
249-
accessor.set(loIx, lo);
250-
accessor.set(hiIx, hi);
249+
// Divide by special factor at the end
250+
if (passes - 1 == pass)
251+
{
252+
divides_assign< complex_t<Scalar> > divAss;
253+
divAss(lo, virtualThreadCount);
254+
divAss(hi, virtualThreadCount);
255+
}
256+
257+
accessor.set((loIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), lo);
258+
accessor.set((hiIx << _NBL_HLSL_WORKGROUP_SIZE_LOG_2_) | SubgroupContiguousIndex(), hi);
251259
}
252260
accessor.memoryBarrier(); // no execution barrier just making sure writes propagate to accessor
253261
stride <<= 1;
254-
}*/
262+
}
255263
}
256264
};
257265

0 commit comments

Comments
 (0)