Skip to content

Commit 94a19eb

Browse files
committed
QOL update to workgroup FFT's exchange value for readability
1 parent ee4474f commit 94a19eb

File tree

1 file changed

+8
-8
lines changed
  • include/nbl/builtin/hlsl/workgroup

1 file changed

+8
-8
lines changed

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ struct exchangeValues<SharedMemoryAdaptor, float16_t>
3232
{
3333
const bool topHalf = bool(threadID & stride);
3434
// Ternary won't take structs so we use this aux variable
35-
uint32_t toExchange = bit_cast<uint32_t, vector <float16_t, 2> >(topHalf ? vector <float16_t, 2> (lo.real(), lo.imag()) : vector <float16_t, 2> (hi.real(), hi.imag()));
35+
uint32_t toExchange = bit_cast<uint32_t, float16_t2 >(topHalf ? float16_t2 (lo.real(), lo.imag()) : float16_t2 (hi.real(), hi.imag()));
3636
shuffleXor<SharedMemoryAdaptor, uint32_t>::__call(toExchange, stride, sharedmemAdaptor);
37-
vector <float16_t, 2> exchanged = bit_cast<vector <float16_t, 2>, uint32_t>(toExchange);
37+
float16_t2 exchanged = bit_cast<float16_t2, uint32_t>(toExchange);
3838
if (topHalf)
3939
{
4040
lo.real(exchanged.x);
@@ -55,9 +55,9 @@ struct exchangeValues<SharedMemoryAdaptor, float32_t>
5555
{
5656
const bool topHalf = bool(threadID & stride);
5757
// Ternary won't take structs so we use this aux variable
58-
vector <uint32_t, 2> toExchange = bit_cast<vector <uint32_t, 2>, vector <float32_t, 2> >(topHalf ? vector <float32_t, 2>(lo.real(), lo.imag()) : vector <float32_t, 2>(hi.real(), hi.imag()));
59-
shuffleXor<SharedMemoryAdaptor, vector <uint32_t, 2> >::__call(toExchange, stride, sharedmemAdaptor);
60-
vector <float32_t, 2> exchanged = bit_cast<vector <float32_t, 2>, vector <uint32_t, 2> >(toExchange);
58+
uint32_t2 toExchange = bit_cast<uint32_t2, float32_t2 >(topHalf ? float32_t2(lo.real(), lo.imag()) : float32_t2(hi.real(), hi.imag()));
59+
shuffleXor<SharedMemoryAdaptor, uint32_t2 >::__call(toExchange, stride, sharedmemAdaptor);
60+
float32_t2 exchanged = bit_cast<float32_t2, uint32_t2 >(toExchange);
6161
if (topHalf)
6262
{
6363
lo.real(exchanged.x);
@@ -78,9 +78,9 @@ struct exchangeValues<SharedMemoryAdaptor, float64_t>
7878
{
7979
const bool topHalf = bool(threadID & stride);
8080
// Ternary won't take structs so we use this aux variable
81-
vector <uint32_t, 4> toExchange = bit_cast<vector <uint32_t, 4>, vector <float64_t, 2> > (topHalf ? vector <float64_t, 2>(lo.real(), lo.imag()) : vector <float64_t, 2>(hi.real(), hi.imag()));
82-
shuffleXor<SharedMemoryAdaptor, vector <uint32_t, 4> >::__call(toExchange, stride, sharedmemAdaptor);
83-
vector <float64_t, 2> exchanged = bit_cast<vector <float64_t, 2>, vector <uint32_t, 4> >(toExchange);
81+
uint32_t4 toExchange = bit_cast<uint32_t4, float64_t2 > (topHalf ? float64_t2(lo.real(), lo.imag()) : float64_t2(hi.real(), hi.imag()));
82+
shuffleXor<SharedMemoryAdaptor, uint32_t4 >::__call(toExchange, stride, sharedmemAdaptor);
83+
float64_t2 exchanged = bit_cast<float64_t2, uint32_t4 >(toExchange);
8484
if (topHalf)
8585
{
8686
lo.real(exchanged.x);

0 commit comments

Comments
 (0)