Skip to content

Commit d4c9bad

Browse files
Merge pull request #556 from theoreticalphysicsftw/blit_hlsl_port
First draft of Blit GLSL to HLSL migration.
2 parents 6b96a1b + c21012a commit d4c9bad

File tree

14 files changed

+2418
-760
lines changed

14 files changed

+2418
-760
lines changed

include/nbl/asset/utils/CHLSLCompiler.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,65 @@ class NBL_API2 CHLSLCompiler final : public IShaderCompiler
4949
std::string preprocessShader(std::string&& code, IShader::E_SHADER_STAGE& stage, const SPreprocessorOptions& preprocessOptions) const override;
5050

5151
void insertIntoStart(std::string& code, std::ostringstream&& ins) const override;
52+
53+
//! Maps a storage-image-capable Nabla format to its SPIR-V/GLSL image format qualifier string
//! (the spelling used in `layout(...)` / `[[vk::image_format(...)]]`).
//! @param format a color format that is legal for storage image use
//! @return the qualifier spelling, or "" (after asserting in debug) for unsupported formats
static inline const char* getStorageImageFormatQualifier(const asset::E_FORMAT format)
{
	// Table-driven mapping keeps every format<->qualifier pair on one line;
	// the set of spellings follows the Vulkan/SPIR-V image format list.
	struct format_qualifier_pair_t
	{
		asset::E_FORMAT format;
		const char* qualifier;
	};
	static const format_qualifier_pair_t mappings[] =
	{
		// floating point
		{ asset::EF_R32G32B32A32_SFLOAT, "rgba32f" },
		{ asset::EF_R16G16B16A16_SFLOAT, "rgba16f" },
		{ asset::EF_R32G32_SFLOAT, "rg32f" },
		{ asset::EF_R16G16_SFLOAT, "rg16f" },
		{ asset::EF_B10G11R11_UFLOAT_PACK32, "r11g11b10f" },
		{ asset::EF_R32_SFLOAT, "r32f" },
		{ asset::EF_R16_SFLOAT, "r16f" },
		// unsigned normalized
		{ asset::EF_R16G16B16A16_UNORM, "rgba16" },
		{ asset::EF_A2B10G10R10_UNORM_PACK32, "rgb10a2" },
		{ asset::EF_R8G8B8A8_UNORM, "rgba8" },
		{ asset::EF_R16G16_UNORM, "rg16" },
		{ asset::EF_R8G8_UNORM, "rg8" },
		{ asset::EF_R16_UNORM, "r16" },
		{ asset::EF_R8_UNORM, "r8" },
		// signed normalized
		{ asset::EF_R16G16B16A16_SNORM, "rgba16snorm" },
		{ asset::EF_R8G8B8A8_SNORM, "rgba8snorm" },
		{ asset::EF_R16G16_SNORM, "rg16snorm" },
		{ asset::EF_R8G8_SNORM, "rg8snorm" },
		{ asset::EF_R16_SNORM, "r16snorm" },
		// unsigned integer
		{ asset::EF_R8_UINT, "r8ui" },
		{ asset::EF_R16_UINT, "r16ui" },
		{ asset::EF_R32_UINT, "r32ui" },
		{ asset::EF_R32G32_UINT, "rg32ui" },
		{ asset::EF_R32G32B32A32_UINT, "rgba32ui" }
	};

	for (const auto& mapping : mappings)
	{
		if (mapping.format == format)
			return mapping.qualifier;
	}

	assert(false); // format has no storage image qualifier mapping
	return "";
}
110+
52111
protected:
53112

54113
// This can't be a unique_ptr due to it being an undefined type
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_BLIT_ALPHA_TEST_INCLUDED_
#define _NBL_BUILTIN_HLSL_BLIT_ALPHA_TEST_INCLUDED_


#include <nbl/builtin/hlsl/cpp_compat.hlsl>

namespace nbl
{
namespace hlsl
{
namespace blit
{

// Counts, per layer (workGroupID.z), how many input texels have alpha above `referenceAlpha`
// by bumping an atomic counter through `passedPixelsAccessor`.
// Invocations outside the input extent `inDim` are culled up front.
template <typename PassedPixelsAccessor, typename InCombinedSamplerAccessor>
inline void alpha_test(
	NBL_REF_ARG(PassedPixelsAccessor) passedPixelsAccessor,
	NBL_CONST_REF_ARG(InCombinedSamplerAccessor) inCombinedSamplerAccessor,
	NBL_CONST_REF_ARG(uint16_t3) inDim,
	NBL_CONST_REF_ARG(float32_t) referenceAlpha,
	NBL_CONST_REF_ARG(uint16_t3) globalInvocationID,
	NBL_CONST_REF_ARG(uint16_t3) workGroupID)
{
	// Guard clause: out-of-range invocations contribute nothing.
	if (!all(globalInvocationID < inDim))
		return;

	const float32_t sampledAlpha = inCombinedSamplerAccessor.get(globalInvocationID, workGroupID.z).a;
	if (sampledAlpha > referenceAlpha)
		passedPixelsAccessor.atomicAdd(workGroupID.z, uint32_t(1));
}

}
}
}

#endif
42+
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_BLIT_COMMON_INCLUDED_
#define _NBL_BUILTIN_HLSL_BLIT_COMMON_INCLUDED_

#include <nbl/builtin/hlsl/cpp_compat.hlsl>

namespace nbl
{
namespace hlsl
{
namespace blit
{
namespace impl
{

// Maps blit dimensionality (1/2/3) to the HLSL texture types used for sampling and storage,
// and converts a 3D virtual coordinate plus an array layer into the texture's index coordinate.
template <uint32_t Dimension>
struct dim_to_image_properties { };

template <>
struct dim_to_image_properties<1>
{
	using combined_sampler_t = Texture1DArray<float4>;
	using image_t = RWTexture1DArray<float4>;

	// 1D arrays are indexed with (x, layer).
	template <typename T>
	static vector<T, 2> getIndexCoord(vector<T, 3> coords, uint32_t layer)
	{
		vector<T, 2> indexCoord;
		indexCoord.x = coords.x;
		indexCoord.y = layer;
		return indexCoord;
	}
};

template <>
struct dim_to_image_properties<2>
{
	using combined_sampler_t = Texture2DArray<float4>;
	using image_t = RWTexture2DArray<float4>;

	// 2D arrays are indexed with (x, y, layer).
	template <typename T>
	static vector<T, 3> getIndexCoord(vector<T, 3> coords, uint32_t layer)
	{
		vector<T, 3> indexCoord;
		indexCoord.xy = coords.xy;
		indexCoord.z = layer;
		return indexCoord;
	}
};

template <>
struct dim_to_image_properties<3>
{
	using combined_sampler_t = Texture3D<float4>;
	using image_t = RWTexture3D<float4>;

	// 3D images have no array layers; the layer argument is ignored.
	template <typename T>
	static vector<T, 3> getIndexCoord(vector<T, 3> coords, uint32_t layer)
	{
		return coords;
	}
};

}


// Compile-time blit configuration, re-exposed as static members so shader code
// can read `ConstevalParameters::X` instead of threading template arguments around.
template<
	uint32_t _WorkGroupSizeX,
	uint32_t _WorkGroupSizeY,
	uint32_t _WorkGroupSizeZ,
	uint32_t _SMemFloatsPerChannel,
	uint32_t _BlitOutChannelCount,
	uint32_t _BlitDimCount,
	uint32_t _AlphaBinCount>
struct consteval_parameters_t
{
	NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkGroupSizeX = _WorkGroupSizeX;
	NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkGroupSizeY = _WorkGroupSizeY;
	NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkGroupSizeZ = _WorkGroupSizeZ;
	// Total invocations per workgroup.
	NBL_CONSTEXPR_STATIC_INLINE uint32_t WorkGroupSize = WorkGroupSizeX * WorkGroupSizeY * WorkGroupSizeZ;
	NBL_CONSTEXPR_STATIC_INLINE uint32_t SMemFloatsPerChannel = _SMemFloatsPerChannel;
	NBL_CONSTEXPR_STATIC_INLINE uint32_t BlitOutChannelCount = _BlitOutChannelCount;
	NBL_CONSTEXPR_STATIC_INLINE uint32_t BlitDimCount = _BlitDimCount;
	NBL_CONSTEXPR_STATIC_INLINE uint32_t AlphaBinCount = _AlphaBinCount;
};

}
}
}

#endif
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_BLIT_INCLUDED_
#define _NBL_BUILTIN_HLSL_BLIT_INCLUDED_


#include <nbl/builtin/hlsl/ndarray_addressing.hlsl>
#include <nbl/builtin/hlsl/blit/parameters.hlsl>
#include <nbl/builtin/hlsl/blit/common.hlsl>


namespace nbl
{
namespace hlsl
{
namespace blit
{

// Workgroup-cooperative separable blit: preloads an input region into shared memory,
// then convolves it one axis at a time (ping-ponging between two shared-memory scratch
// halves), writing final texels to the output image and an alpha histogram.
template <typename ConstevalParameters>
struct compute_blit_t
{
	float32_t3 scale;                           // output-to-input coordinate scale
	float32_t3 negativeSupport;                 // kernel support below the window center
	uint32_t kernelWeightsOffsetY;              // start of Y-axis weights in the weights buffer
	uint32_t kernelWeightsOffsetZ;              // start of Z-axis weights in the weights buffer
	uint32_t inPixelCount;
	uint32_t outPixelCount;
	uint16_t3 outputTexelsPerWG;                // output texels produced per workgroup
	uint16_t3 inDims;
	uint16_t3 outDims;
	uint16_t3 windowDims;                       // per-axis kernel window size
	uint16_t3 phaseCount;                       // per-axis count of distinct kernel phases
	uint16_t3 preloadRegion;                    // input region preloaded into shared memory
	uint16_t3 iterationRegionXPrefixProducts;
	uint16_t3 iterationRegionYPrefixProducts;
	uint16_t3 iterationRegionZPrefixProducts;
	uint16_t secondScratchOffset;               // start of the second shared-memory scratch half

	// Unpacks the packed push-constant/uniform `parameters_t` into a ready-to-run instance.
	static compute_blit_t create(NBL_CONST_REF_ARG(parameters_t) params)
	{
		compute_blit_t compute_blit;

		compute_blit.scale = params.fScale;
		compute_blit.negativeSupport = params.negativeSupport;
		compute_blit.kernelWeightsOffsetY = params.kernelWeightsOffsetY;
		compute_blit.kernelWeightsOffsetZ = params.kernelWeightsOffsetZ;
		compute_blit.inPixelCount = params.inPixelCount;
		compute_blit.outPixelCount = params.outPixelCount;
		compute_blit.outputTexelsPerWG = params.getOutputTexelsPerWG();
		compute_blit.inDims = params.inputDims;
		compute_blit.outDims = params.outputDims;
		compute_blit.windowDims = params.windowDims;
		compute_blit.phaseCount = params.phaseCount;
		compute_blit.preloadRegion = params.preloadRegion;
		compute_blit.iterationRegionXPrefixProducts = params.iterationRegionXPrefixProducts;
		compute_blit.iterationRegionYPrefixProducts = params.iterationRegionYPrefixProducts;
		compute_blit.iterationRegionZPrefixProducts = params.iterationRegionZPrefixProducts;
		compute_blit.secondScratchOffset = params.secondScratchOffset;

		return compute_blit;
	}

	// Runs the blit for one workgroup.
	// @param inCombinedSamplerAccessor reads input texels via normalized coords (HW sampler handles wrapping)
	// @param outImageAccessor writes final output texels
	// @param kernelWeightsAccessor reads per-phase kernel weights (all axes in one buffer)
	// @param histogramAccessor atomically accumulates the per-layer alpha histogram
	// @param sharedAccessor workgroup shared memory: channel-major, two ping-pong scratch halves
	template <
		typename InCombinedSamplerAccessor,
		typename OutImageAccessor,
		typename KernelWeightsAccessor,
		typename HistogramAccessor,
		typename SharedAccessor>
	void execute(
		NBL_CONST_REF_ARG(InCombinedSamplerAccessor) inCombinedSamplerAccessor,
		NBL_REF_ARG(OutImageAccessor) outImageAccessor,
		NBL_CONST_REF_ARG(KernelWeightsAccessor) kernelWeightsAccessor,
		NBL_REF_ARG(HistogramAccessor) histogramAccessor,
		NBL_REF_ARG(SharedAccessor) sharedAccessor,
		uint16_t3 workGroupID,
		uint16_t localInvocationIndex)
	{
		const float3 halfScale = scale * float3(0.5f, 0.5f, 0.5f);
		const uint32_t3 minOutputPixel = workGroupID * outputTexelsPerWG;
		const float3 minOutputPixelCenterOfWG = float3(minOutputPixel)*scale + halfScale;
		// this can be negative, in which case HW sampler takes care of wrapping for us
		const int32_t3 regionStartCoord = int32_t3(ceil(minOutputPixelCenterOfWG - float3(0.5f, 0.5f, 0.5f) + negativeSupport));

		// Phase 1: cooperatively preload the required input region into shared memory (channel-major).
		const uint32_t virtualInvocations = preloadRegion.x * preloadRegion.y * preloadRegion.z;
		for (uint32_t virtualInvocation = localInvocationIndex; virtualInvocation < virtualInvocations; virtualInvocation += ConstevalParameters::WorkGroupSize)
		{
			const int32_t3 inputPixelCoord = regionStartCoord + int32_t3(ndarray_addressing::snakeCurveInverse(virtualInvocation, preloadRegion));
			float32_t3 inputTexCoord = (inputPixelCoord + float32_t3(0.5f, 0.5f, 0.5f)) / inDims;
			const float4 loadedData = inCombinedSamplerAccessor.get(inputTexCoord, workGroupID.z);

			for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
				sharedAccessor.set(ch * ConstevalParameters::SMemFloatsPerChannel + virtualInvocation, loadedData[ch]);
		}
		GroupMemoryBarrierWithGroupSync();

		// Phase 2: separable convolution, one axis per pass, ping-ponging between scratch halves.
		const uint32_t3 iterationRegionPrefixProducts[3] = {iterationRegionXPrefixProducts, iterationRegionYPrefixProducts, iterationRegionZPrefixProducts};

		uint32_t readScratchOffset = 0;
		uint32_t writeScratchOffset = secondScratchOffset;
		for (uint32_t axis = 0; axis < ConstevalParameters::BlitDimCount; ++axis)
		{
			for (uint32_t virtualInvocation = localInvocationIndex; virtualInvocation < iterationRegionPrefixProducts[axis].z; virtualInvocation += ConstevalParameters::WorkGroupSize)
			{
				const uint32_t3 virtualInvocationID = ndarray_addressing::snakeCurveInverse(virtualInvocation, iterationRegionPrefixProducts[axis].xy);

				uint32_t outputPixel = virtualInvocationID.x;
				if (axis == 2)
					outputPixel = virtualInvocationID.z;
				outputPixel += minOutputPixel[axis];

				if (outputPixel >= outDims[axis])
					break;

				const int32_t minKernelWindow = int32_t(ceil((outputPixel + 0.5f) * scale[axis] - 0.5f + negativeSupport[axis]));

				// Combined stride for the two non-blitting dimensions, tightly coupled and experimentally derived with/by `iterationRegionPrefixProducts` above and the general order of iteration we use to avoid
				// read bank conflicts.
				uint32_t combinedStride;
				{
					if (axis == 0)
						combinedStride = virtualInvocationID.z * preloadRegion.y + virtualInvocationID.y;
					else if (axis == 1)
						combinedStride = virtualInvocationID.z * outputTexelsPerWG.x + virtualInvocationID.y;
					else if (axis == 2)
						combinedStride = virtualInvocationID.y * outputTexelsPerWG.y + virtualInvocationID.x;
				}

				uint32_t offset = readScratchOffset + (minKernelWindow - regionStartCoord[axis]) + combinedStride*preloadRegion[axis];
				// windows on the same phase share the same weights
				const uint32_t windowPhase = outputPixel % phaseCount[axis];

				uint32_t kernelWeightIndex;
				if (axis == 0)
					kernelWeightIndex = windowPhase * windowDims.x;
				else if (axis == 1)
					kernelWeightIndex = kernelWeightsOffsetY + windowPhase * windowDims.y;
				else if (axis == 2)
					kernelWeightIndex = kernelWeightsOffsetZ + windowPhase * windowDims.z;

				float4 kernelWeight = kernelWeightsAccessor.get(kernelWeightIndex);

				float4 accum = float4(0.f, 0.f, 0.f, 0.f);
				for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
					accum[ch] = sharedAccessor.get(ch * ConstevalParameters::SMemFloatsPerChannel + offset) * kernelWeight[ch];

				for (uint32_t i = 1; i < windowDims[axis]; ++i)
				{
					kernelWeightIndex++;
					offset++;

					kernelWeight = kernelWeightsAccessor.get(kernelWeightIndex);
					for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
						accum[ch] += sharedAccessor.get(ch * ConstevalParameters::SMemFloatsPerChannel + offset) * kernelWeight[ch];
				}

				const bool lastPass = (axis == (ConstevalParameters::BlitDimCount - 1));
				if (lastPass)
				{
					// Tightly coupled with iteration order (`iterationRegionPrefixProducts`)
					uint32_t3 outCoord = virtualInvocationID.yxz;
					if (axis == 0)
						outCoord = virtualInvocationID.xyz;
					outCoord += minOutputPixel;

					const uint32_t bucketIndex = uint32_t(round(clamp(accum.a, 0.f, 1.f) * float(ConstevalParameters::AlphaBinCount-1)));
					histogramAccessor.atomicAdd(workGroupID.z, bucketIndex, uint32_t(1));

					outImageAccessor.set(outCoord, workGroupID.z, accum);
				}
				else
				{
					// BUGFIX: the non-X branch previously computed `writeScratchOffset + writeScratchOffset + snakeCurve(...)`
					// (the base was added both at initialization and in the `else`). It only worked by accident because
					// `writeScratchOffset` is 0 whenever a non-zero axis writes scratch; add the base exactly once.
					uint32_t scratchOffset = writeScratchOffset;
					if (axis == 0)
						scratchOffset += ndarray_addressing::snakeCurve(virtualInvocationID.yxz, uint32_t3(preloadRegion.y, outputTexelsPerWG.x, preloadRegion.z));
					else
						scratchOffset += ndarray_addressing::snakeCurve(virtualInvocationID.zxy, uint32_t3(preloadRegion.z, outputTexelsPerWG.y, outputTexelsPerWG.x));

					for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
						sharedAccessor.set(ch * ConstevalParameters::SMemFloatsPerChannel + scratchOffset, accum[ch]);
				}
			}

			// Ping-pong the two scratch halves for the next axis.
			const uint32_t tmp = readScratchOffset;
			readScratchOffset = writeScratchOffset;
			writeScratchOffset = tmp;
			GroupMemoryBarrierWithGroupSync();
		}
	}
};

}
}
}

#endif

0 commit comments

Comments
 (0)