Skip to content

Commit b3cc63e

Browse files
committed
better "nbl_glsl_ext_FFT_getPaddedData"
1 parent 32eddff commit b3cc63e

File tree

3 files changed

+31
-63
lines changed

3 files changed

+31
-63
lines changed

include/nbl/builtin/glsl/ext/FFT/fft.glsl

Lines changed: 8 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,20 @@ vec2 nbl_glsl_ext_FFT_getData(in uvec3 coordinate, in uint channel);
6666
void nbl_glsl_ext_FFT_setData(in uvec3 coordinate, in uint channel, in vec2 complex_value);
6767
#endif
6868

69+
#ifndef _NBL_GLSL_EXT_FFT_GET_PADDED_DATA_DECLARED_
70+
#define _NBL_GLSL_EXT_FFT_GET_PADDED_DATA_DECLARED_
71+
vec2 nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint channel);
72+
#endif
73+
6974
#ifndef _NBL_GLSL_EXT_FFT_GET_DATA_DEFINED_
7075
#error "You need to define `nbl_glsl_ext_FFT_getData` and mark `_NBL_GLSL_EXT_FFT_GET_DATA_DEFINED_`!"
7176
#endif
7277
#ifndef _NBL_GLSL_EXT_FFT_SET_DATA_DEFINED_
7378
#error "You need to define `nbl_glsl_ext_FFT_setData` and mark `_NBL_GLSL_EXT_FFT_SET_DATA_DEFINED_`!"
7479
#endif
80+
#ifndef _NBL_GLSL_EXT_FFT_GET_PADDED_DATA_DEFINED_
81+
#error "You need to define `nbl_glsl_ext_FFT_getPaddedData` and mark `_NBL_GLSL_EXT_FFT_GET_PADDED_DATA_DEFINED_`!"
82+
#endif
7583

7684
// Count Leading Zeroes (naive?)
7785
uint nbl_glsl_ext_FFT_clz(in uint x)
@@ -127,61 +135,6 @@ uint nbl_glsl_ext_FFT_getDimLength(uvec3 dimension)
127135
return dimension[pc.direction];
128136
}
129137

130-
vec2 nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint channel) {
131-
uint min_x = 0u;
132-
uint max_x = pc.dimension.x + min_x - 1u;
133-
134-
uint min_y = 0u;
135-
uint max_y = pc.dimension.y + min_y - 1u;
136-
137-
uint min_z = 0u;
138-
uint max_z = pc.dimension.z + min_z - 1u;
139-
140-
141-
uvec3 actual_coord = uvec3(0u, 0u, 0u);
142-
143-
if(_NBL_GLSL_EXT_FFT_CLAMP_TO_EDGE_ == pc.padding_type) {
144-
if (coordinate.x < min_x) {
145-
actual_coord.x = 0u;
146-
} else if(coordinate.x > max_x) {
147-
actual_coord.x = pc.dimension.x - 1u;
148-
} else {
149-
actual_coord.x = coordinate.x - min_x;
150-
}
151-
152-
if (coordinate.y < min_y) {
153-
actual_coord.y = 0u;
154-
} else if (coordinate.y > max_y) {
155-
actual_coord.y = pc.dimension.y - 1u;
156-
} else {
157-
actual_coord.y = coordinate.y - min_y;
158-
}
159-
160-
if (coordinate.z < min_z) {
161-
actual_coord.z = 0u;
162-
} else if (coordinate.z > max_z) {
163-
actual_coord.z = pc.dimension.z - 1u;
164-
} else {
165-
actual_coord.z = coordinate.z - min_z;
166-
}
167-
168-
} else if (_NBL_GLSL_EXT_FFT_FILL_WITH_ZERO_ == pc.padding_type) {
169-
170-
if ( coordinate.x < min_x || coordinate.x > max_x ||
171-
coordinate.y < min_y || coordinate.y > max_y ||
172-
coordinate.z < min_z || coordinate.z > max_z ) {
173-
return vec2(0, 0);
174-
}
175-
176-
actual_coord.x = coordinate.x - min_x;
177-
actual_coord.y = coordinate.y - min_y;
178-
actual_coord.z = coordinate.z - min_z;
179-
180-
}
181-
182-
return nbl_glsl_ext_FFT_getData(actual_coord, channel);
183-
}
184-
185138
void nbl_glsl_ext_FFT()
186139
{
187140
// Virtual Threads Calculation

include/nbl/ext/FFT/FFT.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class FFT : public core::TotalInterface
106106
return (paddedInputDimensions.width * paddedInputDimensions.height * paddedInputDimensions.depth * numChannels) * (sizeof(float) * 2);
107107
}
108108

109-
static core::smart_refctd_ptr<video::IGPUSpecializedShader> createShader(video::IVideoDriver* driver, DataType inputType, uint32_t maxPaddedDimensionSize);
109+
static core::smart_refctd_ptr<video::IGPUSpecializedShader> createShader(video::IVideoDriver* driver, DataType inputType, uint32_t maxDimensionSize);
110110

111111
_NBL_STATIC_INLINE_CONSTEXPR uint32_t MAX_DESCRIPTOR_COUNT = 2u;
112112
static inline void updateDescriptorSet(

src/nbl/ext/FFT/FFT.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ core::SRange<const video::IGPUDescriptorSetLayout::SBinding> FFT::getDefaultBind
9898
return {bnd, bnd+sizeof(bnd)/sizeof(IGPUDescriptorSetLayout::SBinding)};
9999
}
100100

101-
core::smart_refctd_ptr<video::IGPUSpecializedShader> FFT::createShader(video::IVideoDriver* driver, DataType inputType, uint32_t maxPaddedDimensionSize)
101+
core::smart_refctd_ptr<video::IGPUSpecializedShader> FFT::createShader(video::IVideoDriver* driver, DataType inputType, uint32_t maxDimensionSize)
102102
{
103-
assert(core::isPoT(maxPaddedDimensionSize));
103+
uint32_t const maxPaddedDimensionSize = core::roundUpToPoT(maxDimensionSize);
104104

105105
const char* sourceFmt =
106106
R"===(#version 430 core
@@ -130,6 +130,7 @@ layout(local_size_x=_NBL_GLSL_EXT_FFT_BLOCK_SIZE_X_DEFINED_, local_size_y=_NBL_G
130130
131131
#define _NBL_GLSL_EXT_FFT_GET_DATA_DEFINED_
132132
#define _NBL_GLSL_EXT_FFT_SET_DATA_DEFINED_
133+
#define _NBL_GLSL_EXT_FFT_GET_PADDED_DATA_DEFINED_
133134
#include "nbl/builtin/glsl/ext/FFT/fft.glsl"
134135
135136
// Input Descriptor
@@ -187,7 +188,7 @@ layout(set=_NBL_GLSL_EXT_FFT_OUTPUT_SET_DEFINED_, binding=_NBL_GLSL_EXT_FFT_OUTP
187188
188189
// Get/Set Data Function
189190
190-
vec2 nbl_glsl_ext_FFT_getData(in uvec3 coordinate, in uint channel)
191+
nbl_glsl_complex nbl_glsl_ext_FFT_getData(in uvec3 coordinate, in uint channel)
191192
{
192193
vec2 retValue = vec2(0, 0);
193194
#if USE_SSBO_FOR_INPUT > 0
@@ -201,13 +202,27 @@ vec2 nbl_glsl_ext_FFT_getData(in uvec3 coordinate, in uint channel)
201202
return retValue;
202203
}
203204
204-
void nbl_glsl_ext_FFT_setData(in uvec3 coordinate, in uint channel, in vec2 complex_value)
205+
void nbl_glsl_ext_FFT_setData(in uvec3 coordinate, in uint channel, in nbl_glsl_complex complex_value)
205206
{
206207
uvec3 dimension = pc.padded_dimension;
207208
uint index = channel * (dimension.x * dimension.y * dimension.z) + coordinate.z * (dimension.x * dimension.y) + coordinate.y * (dimension.x) + coordinate.x;
208209
outData[index].complex_value = complex_value;
209210
}
210211
212+
nbl_glsl_complex nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint channel) {
213+
214+
uvec3 max_coord = pc.dimension - uvec3(1u);
215+
uvec3 clamped_coord = min(coordinate, max_coord);
216+
217+
bool is_out_of_range = any(bvec3(coordinate!=clamped_coord));
218+
219+
if (_NBL_GLSL_EXT_FFT_FILL_WITH_ZERO_ == pc.padding_type && is_out_of_range) {
220+
return nbl_glsl_complex(0, 0);
221+
}
222+
223+
return nbl_glsl_ext_FFT_getData(clamped_coord, channel);
224+
}
225+
211226
void main()
212227
{
213228
nbl_glsl_ext_FFT();
@@ -271,9 +286,9 @@ layout(set=0, binding=1) restrict buffer OutBuffer
271286
272287
void main()
273288
{
274-
float power = length(in_data[0].complex_value);
275-
vec2 normalized_data = in_data[gl_GlobalInvocationID.x].complex_value / power;
276-
out_data[gl_GlobalInvocationID.x].complex_value = normalized_data;
289+
float power = length(in_data[0].complex_value);
290+
vec2 normalized_data = in_data[gl_GlobalInvocationID.x].complex_value / power;
291+
out_data[gl_GlobalInvocationID.x].complex_value = normalized_data;
277292
}
278293
)===";
279294

0 commit comments

Comments
 (0)