@@ -18,7 +18,8 @@ namespace FFT
18
18
{
19
19
20
20
typedef uint32_t uint;
21
- struct alignas (16 ) uvec3 {
21
+ struct alignas (16 ) uvec3
22
+ {
22
23
uint x,y,z;
23
24
};
24
25
struct alignas (16 ) uvec4 {
@@ -50,45 +51,85 @@ class FFT final : public core::IReferenceCounted
50
51
{
51
52
uint32_t passesRequired = 0u ;
52
53
54
+ const auto paddedInputDimensions = padDimensions (extraPaddedInputDimensions);
55
+
56
+ using SizeAxisPair = std::tuple<uint32_t ,uint8_t ,uint8_t >;
57
+ std::array<SizeAxisPair,3u > passes;
53
58
if (numChannels)
54
59
{
55
- const auto paddedInputDimensions = padDimensions (extraPaddedInputDimensions);
56
60
for (uint32_t i=0u ; i<3u ; i++)
57
- if ((&inputDimensions.width )[i]>1u )
58
61
{
59
- // TODO: rework
60
- auto & dispatch = outInfos[passesRequired];
61
- dispatch.workGroupCount [0 ] = paddedInputDimensions.width ;
62
- dispatch.workGroupCount [1 ] = paddedInputDimensions.height ;
63
- dispatch.workGroupCount [2 ] = paddedInputDimensions.depth ;
64
- dispatch.workGroupCount [i] = 1u ;
65
-
66
- auto & params = outParams[passesRequired];
67
- params.input_dimensions .x = inputDimensions.width ;
68
- params.input_dimensions .y = inputDimensions.height ;
69
- params.input_dimensions .z = inputDimensions.depth ;
62
+ auto dim = (&paddedInputDimensions.width )[i];
63
+ if (dim<2u )
64
+ continue ;
65
+ passes[passesRequired++] = {dim,i,paddingType[i]};
66
+ }
67
+ std::sort (passes.begin (),passes.begin ()+passesRequired,[](const auto & lhs, const auto & rhs)->bool {return std::get<0u >(lhs)>std::get<0u >(rhs);});
68
+ }
69
+
70
+ auto computeOutputStride = [](const uvec3& output_dimensions, const auto axis, const auto nextAxis) -> uvec4
71
+ {
72
+ // coord[axis] = 1u
73
+ // coord[nextAxis] = fftLen;
74
+ // coord[otherAxis] = fftLen*dimension[nextAxis];
75
+ uvec4 stride;
76
+ stride.w = output_dimensions.x *output_dimensions.y *output_dimensions.z ;
77
+ for (auto i=0u ; i<3u ; i++)
78
+ {
79
+ auto & coord = (&stride.x )[i];
80
+ if (i!=axis)
81
+ {
82
+ coord = (&output_dimensions.x )[axis];
83
+ if (i!=nextAxis)
84
+ coord *= (&output_dimensions.x )[nextAxis];
85
+ }
86
+ else
87
+ coord = 1u ;
88
+ }
89
+ return stride;
90
+ };
91
+
92
+ if (passesRequired)
93
+ {
94
+ uvec3 output_dimensions = {inputDimensions.width ,inputDimensions.height ,inputDimensions.depth };
95
+ for (uint32_t i=0u ; i<passesRequired; i++)
96
+ {
97
+ auto & params = outParams[i];
98
+ params.input_dimensions .x = output_dimensions.x ;
99
+ params.input_dimensions .y = output_dimensions.y ;
100
+ params.input_dimensions .z = output_dimensions.z ;
101
+
102
+ const auto paddedAxisLen = std::get<0u >(passes[i]);
70
103
{
71
- const uint32_t fftSize = (&paddedInputDimensions.width )[i];
72
104
assert (paddingType[i]<=asset::ISampler::ETC_MIRROR);
73
105
params.input_dimensions .w = (isInverse ? 0x80000000u :0x0u )|
74
106
(i<<28u )| // direction
75
107
((numChannels-1u )<<26u )| // max channel
76
- (core::findMSB (fftSize )<<3u )| // log2(fftSize)
77
- uint32_t (paddingType [i]);
108
+ (core::findMSB (paddedAxisLen )<<3u )| // log2(fftSize)
109
+ uint32_t (std::get< 2u >(passes [i]) );
78
110
}
79
- params.input_strides .x = 1u ;
80
- params.input_strides .y = paddedInputDimensions.width ;
81
- params.input_strides .z = params.input_strides .y *paddedInputDimensions.height ;
82
- params.input_strides .w = params.input_strides .z *paddedInputDimensions.depth ;
83
- params.output_strides = params.input_strides ;
84
111
85
- passesRequired++;
112
+ const auto passAxis = std::get<1u >(passes[i]);
113
+ (&output_dimensions.x )[passAxis] = paddedAxisLen;
114
+ if (i)
115
+ params.input_strides = outParams[i-1u ].output_strides ;
116
+ else // TODO provide an override for input strides
117
+ {
118
+ params.input_strides .x = 1u ;
119
+ params.input_strides .y = inputDimensions.width ;
120
+ params.input_strides .z = params.input_strides .y * inputDimensions.height ;
121
+ params.input_strides .w = params.input_strides .z * inputDimensions.depth ;
122
+ }
123
+ params.output_strides = computeOutputStride (output_dimensions,passAxis,std::get<1u >(passes[(i+1u )%passesRequired]));
124
+
125
+ auto & dispatch = outInfos[i];
126
+ dispatch.workGroupCount [0 ] = output_dimensions.x ;
127
+ dispatch.workGroupCount [1 ] = output_dimensions.y ;
128
+ dispatch.workGroupCount [2 ] = output_dimensions.z ;
129
+ dispatch.workGroupCount [passAxis] = 1u ;
86
130
}
87
131
}
88
132
89
- if (passesRequired)
90
- outParams[passesRequired-1u ].output_strides = outParams[0 ].input_strides ;
91
-
92
133
return passesRequired;
93
134
}
94
135
static inline uint32_t buildParameters (
0 commit comments