@@ -24,18 +24,14 @@ namespace fft
24
24
25
25
// ---------------------------------- Utils -----------------------------------------------
26
26
template<typename SharedMemoryAdaptor, typename Scalar>
27
- struct exchangeValues;
28
-
29
- template<typename SharedMemoryAdaptor>
30
- struct exchangeValues<SharedMemoryAdaptor, float16_t>
27
+ struct exchangeValues
31
28
{
32
- static void __call (NBL_REF_ARG (complex_t<float16_t >) lo, NBL_REF_ARG (complex_t<float16_t >) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
29
+ static void __call (NBL_REF_ARG (complex_t<Scalar >) lo, NBL_REF_ARG (complex_t<Scalar >) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
33
30
{
34
31
const bool topHalf = bool (threadID & stride);
35
- // Pack two halves into a single uint32_t
36
- uint32_t toExchange = bit_cast<uint32_t, float16_t2 >(topHalf ? float16_t2 (lo.real (), lo.imag ()) : float16_t2 (hi.real (), hi.imag ()));
37
- shuffleXor<SharedMemoryAdaptor, uint32_t>(toExchange, stride, sharedmemAdaptor);
38
- float16_t2 exchanged = bit_cast<float16_t2, uint32_t>(toExchange);
32
+ // Pack into float vector because ternary operator does not support structs
33
+ vector <Scalar, 2 > exchanged = topHalf ? vector <Scalar, 2 >(lo.real (), lo.imag ()) : vector <Scalar, 2 >(hi.real (), hi.imag ());
34
+ shuffleXor<SharedMemoryAdaptor, vector <Scalar, 2 > >(exchanged, stride, sharedmemAdaptor);
39
35
if (topHalf)
40
36
{
41
37
lo.real (exchanged.x);
@@ -45,51 +41,7 @@ struct exchangeValues<SharedMemoryAdaptor, float16_t>
45
41
{
46
42
hi.real (exchanged.x);
47
43
lo.imag (exchanged.y);
48
- }
49
- }
50
- };
51
-
52
- template<typename SharedMemoryAdaptor>
53
- struct exchangeValues<SharedMemoryAdaptor, float32_t>
54
- {
55
- static void __call (NBL_REF_ARG (complex_t<float32_t>) lo, NBL_REF_ARG (complex_t<float32_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
56
- {
57
- const bool topHalf = bool (threadID & stride);
58
- // pack into `float32_t2` because ternary operator doesn't support structs
59
- float32_t2 exchanged = topHalf ? float32_t2 (lo.real (), lo.imag ()) : float32_t2 (hi.real (), hi.imag ());
60
- shuffleXor<SharedMemoryAdaptor, float32_t2>(exchanged, stride, sharedmemAdaptor);
61
- if (topHalf)
62
- {
63
- lo.real (exchanged.x);
64
- lo.imag (exchanged.y);
65
44
}
66
- else
67
- {
68
- hi.real (exchanged.x);
69
- hi.imag (exchanged.y);
70
- }
71
- }
72
- };
73
-
74
- template<typename SharedMemoryAdaptor>
75
- struct exchangeValues<SharedMemoryAdaptor, float64_t>
76
- {
77
- static void __call (NBL_REF_ARG (complex_t<float64_t>) lo, NBL_REF_ARG (complex_t<float64_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG (SharedMemoryAdaptor) sharedmemAdaptor)
78
- {
79
- const bool topHalf = bool (threadID & stride);
80
- // pack into `float64_t2` because ternary operator doesn't support structs
81
- float64_t2 exchanged = topHalf ? float64_t2 (lo.real (), lo.imag ()) : float64_t2 (hi.real (), hi.imag ());
82
- shuffleXor<SharedMemoryAdaptor, float64_t2 >(exchanged, stride, sharedmemAdaptor);
83
- if (topHalf)
84
- {
85
- lo.real (exchanged.x);
86
- lo.imag (exchanged.y);
87
- }
88
- else
89
- {
90
- hi.real (exchanged.x);
91
- hi.imag (exchanged.y);
92
- }
93
45
}
94
46
};
95
47
@@ -170,7 +122,7 @@ uint32_t getNegativeIndex(uint32_t idx)
170
122
171
123
// Utility to unpack two values from the packed FFT of X + iY. Results are written back into the input arguments: x is stored in `lo` and y in `hi`
172
124
template<typename Scalar>
173
- void unpack (NBL_CONST_REF_ARG (complex_t<Scalar>) lo, NBL_CONST_REF_ARG (complex_t<Scalar>) hi)
125
+ void unpack (NBL_REF_ARG (complex_t<Scalar>) lo, NBL_REF_ARG (complex_t<Scalar>) hi)
174
126
{
175
127
complex_t<Scalar> x = (lo + conj (hi)) * Scalar (0.5 );
176
128
hi = rotateRight<Scalar>(lo - conj (hi)) * 0.5 ;
0 commit comments