// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_BLIT_INCLUDED_
#define _NBL_BUILTIN_HLSL_BLIT_INCLUDED_


#include <nbl/builtin/hlsl/ndarray_addressing.hlsl>
#include <nbl/builtin/hlsl/blit/parameters.hlsl>
#include <nbl/builtin/hlsl/blit/common.hlsl>


namespace nbl
{
namespace hlsl
{
namespace blit
{
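
// Separable compute blit: each workgroup produces `outputTexelsPerWG` texels of the
// output by preloading its whole input support window into groupshared memory, then
// convolving one axis at a time with precomputed polyphase kernel weights,
// ping-ponging intermediate results between two halves of the shared scratch.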
template <typename ConstevalParameters>
struct compute_blit_t
{
	float32_t3 scale;
	float32_t3 negativeSupport;
	uint32_t kernelWeightsOffsetY;
	uint32_t kernelWeightsOffsetZ;
	uint32_t inPixelCount;
	uint32_t outPixelCount;
	uint16_t3 outputTexelsPerWG;
	uint16_t3 inDims;
	uint16_t3 outDims;
	uint16_t3 windowDims;
	uint16_t3 phaseCount;
	uint16_t3 preloadRegion;
	uint16_t3 iterationRegionXPrefixProducts;
	uint16_t3 iterationRegionYPrefixProducts;
	uint16_t3 iterationRegionZPrefixProducts;
	uint16_t secondScratchOffset;

	static compute_blit_t create(NBL_CONST_REF_ARG(parameters_t) params)
	{
		compute_blit_t compute_blit;

		compute_blit.scale = params.fScale;
		compute_blit.negativeSupport = params.negativeSupport;
		compute_blit.kernelWeightsOffsetY = params.kernelWeightsOffsetY;
		compute_blit.kernelWeightsOffsetZ = params.kernelWeightsOffsetZ;
		compute_blit.inPixelCount = params.inPixelCount;
		compute_blit.outPixelCount = params.outPixelCount;
		compute_blit.outputTexelsPerWG = params.getOutputTexelsPerWG();
		compute_blit.inDims = params.inputDims;
		compute_blit.outDims = params.outputDims;
		compute_blit.windowDims = params.windowDims;
		compute_blit.phaseCount = params.phaseCount;
		compute_blit.preloadRegion = params.preloadRegion;
		compute_blit.iterationRegionXPrefixProducts = params.iterationRegionXPrefixProducts;
		compute_blit.iterationRegionYPrefixProducts = params.iterationRegionYPrefixProducts;
		compute_blit.iterationRegionZPrefixProducts = params.iterationRegionZPrefixProducts;
		compute_blit.secondScratchOffset = params.secondScratchOffset;

		return compute_blit;
	}
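
	// `execute` is generic over accessor types rather than bound resources. A minimal
	// sketch of the shapes its call sites assume (hypothetical illustration, not the
	// engine's canonical accessor definitions):
	//
	//	struct InCombinedSamplerAccessor
	//	{
	//		// normalized coords; layer selects the array slice; HW sampler wraps OOB
	//		float32_t4 get(const float32_t3 uvw, const uint32_t layer);
	//	};
	//	struct SharedAccessor
	//	{
	//		float32_t get(const uint32_t index);
	//		void set(const uint32_t index, const float32_t value);
	//	};
	//	struct HistogramAccessor
	//	{
	//		void atomicAdd(const uint32_t wgID, const uint32_t bucket, const uint32_t value);
	//	};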
	template <
		typename InCombinedSamplerAccessor,
		typename OutImageAccessor,
		typename KernelWeightsAccessor,
		typename HistogramAccessor,
		typename SharedAccessor>
	void execute(
		NBL_CONST_REF_ARG(InCombinedSamplerAccessor) inCombinedSamplerAccessor,
		NBL_REF_ARG(OutImageAccessor) outImageAccessor,
		NBL_CONST_REF_ARG(KernelWeightsAccessor) kernelWeightsAccessor,
		NBL_REF_ARG(HistogramAccessor) histogramAccessor,
		NBL_REF_ARG(SharedAccessor) sharedAccessor,
		uint16_t3 workGroupID,
		uint16_t localInvocationIndex)
	{
		const float3 halfScale = scale * float3(0.5f, 0.5f, 0.5f);
		const uint32_t3 minOutputPixel = workGroupID * outputTexelsPerWG;
		const float3 minOutputPixelCenterOfWG = float3(minOutputPixel) * scale + halfScale;
		// this can be negative, in which case HW sampler takes care of wrapping for us
		const int32_t3 regionStartCoord = int32_t3(ceil(minOutputPixelCenterOfWG - float3(0.5f, 0.5f, 0.5f) + negativeSupport));
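
		// Preload the whole input support region for this workgroup into groupshared
		// memory, one `SMemFloatsPerChannel`-sized plane per channel, so the per-axis
		// passes below never touch global memory again. `snakeCurveInverse` appears to
		// delinearize a flat index against an extent with x fastest-moving; e.g. for
		// extent (4,4,4), index 21 -> (1,1,1) since 1 + 4*(1 + 4*1) == 21.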
		const uint32_t virtualInvocations = preloadRegion.x * preloadRegion.y * preloadRegion.z;
		for (uint32_t virtualInvocation = localInvocationIndex; virtualInvocation < virtualInvocations; virtualInvocation += ConstevalParameters::WorkGroupSize)
		{
			const int32_t3 inputPixelCoord = regionStartCoord + int32_t3(ndarray_addressing::snakeCurveInverse(virtualInvocation, preloadRegion));
			float32_t3 inputTexCoord = (inputPixelCoord + float32_t3(0.5f, 0.5f, 0.5f)) / inDims;
			const float4 loadedData = inCombinedSamplerAccessor.get(inputTexCoord, workGroupID.z);

			for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
				sharedAccessor.set(ch * ConstevalParameters::SMemFloatsPerChannel + virtualInvocation, loadedData[ch]);
		}
		GroupMemoryBarrierWithGroupSync();
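
		// Each `iterationRegion*PrefixProducts` presumably holds {width, width*height,
		// width*height*depth} of that pass's iteration space: `.xy` feeds the
		// delinearization and `.z` bounds the flat loop below.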
		const uint32_t3 iterationRegionPrefixProducts[3] = { iterationRegionXPrefixProducts, iterationRegionYPrefixProducts, iterationRegionZPrefixProducts };

		uint32_t readScratchOffset = 0;
		uint32_t writeScratchOffset = secondScratchOffset;
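
		// One separable pass per blit axis: read the previous pass's result from one
		// half of scratch, convolve along `axis`, write into the other half; the last
		// pass writes the output image instead.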
		for (uint32_t axis = 0; axis < ConstevalParameters::BlitDimCount; ++axis)
		{
			for (uint32_t virtualInvocation = localInvocationIndex; virtualInvocation < iterationRegionPrefixProducts[axis].z; virtualInvocation += ConstevalParameters::WorkGroupSize)
			{
				const uint32_t3 virtualInvocationID = ndarray_addressing::snakeCurveInverse(virtualInvocation, iterationRegionPrefixProducts[axis].xy);

				uint32_t outputPixel = virtualInvocationID.x;
				if (axis == 2)
					outputPixel = virtualInvocationID.z;
				outputPixel += minOutputPixel[axis];

				if (outputPixel >= outDims[axis])
					break;

				const int32_t minKernelWindow = int32_t(ceil((outputPixel + 0.5f) * scale[axis] - 0.5f + negativeSupport[axis]));

				// Combined stride for the two non-blitting dimensions; tightly coupled to
				// `iterationRegionPrefixProducts` above and the iteration order we use,
				// experimentally derived to avoid read bank conflicts.
				uint32_t combinedStride;
				{
					if (axis == 0)
						combinedStride = virtualInvocationID.z * preloadRegion.y + virtualInvocationID.y;
					else if (axis == 1)
						combinedStride = virtualInvocationID.z * outputTexelsPerWG.x + virtualInvocationID.y;
					else if (axis == 2)
						combinedStride = virtualInvocationID.y * outputTexelsPerWG.y + virtualInvocationID.x;
				}
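
				// Polyphase filtering: output pixels repeat the same fractional sampling
				// positions every `phaseCount[axis]` pixels, so weights are precomputed
				// per phase and `windowPhase` selects the window for this output.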
				uint32_t offset = readScratchOffset + (minKernelWindow - regionStartCoord[axis]) + combinedStride * preloadRegion[axis];
				const uint32_t windowPhase = outputPixel % phaseCount[axis];

				uint32_t kernelWeightIndex;
				if (axis == 0)
					kernelWeightIndex = windowPhase * windowDims.x;
				else if (axis == 1)
					kernelWeightIndex = kernelWeightsOffsetY + windowPhase * windowDims.y;
				else if (axis == 2)
					kernelWeightIndex = kernelWeightsOffsetZ + windowPhase * windowDims.z;
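
				// 1D convolution along `axis`: multiply-accumulate across the kernel
				// window, all channels at once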
				float4 kernelWeight = kernelWeightsAccessor.get(kernelWeightIndex);

				float4 accum = float4(0.f, 0.f, 0.f, 0.f);
				for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
					accum[ch] = sharedAccessor.get(ch * ConstevalParameters::SMemFloatsPerChannel + offset) * kernelWeight[ch];

				for (uint32_t i = 1; i < windowDims[axis]; ++i)
				{
					kernelWeightIndex++;
					offset++;

					kernelWeight = kernelWeightsAccessor.get(kernelWeightIndex);
					for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
						accum[ch] += sharedAccessor.get(ch * ConstevalParameters::SMemFloatsPerChannel + offset) * kernelWeight[ch];
				}

				const bool lastPass = (axis == (ConstevalParameters::BlitDimCount - 1));
				if (lastPass)
				{
					// Tightly coupled with iteration order (`iterationRegionPrefixProducts`)
					uint32_t3 outCoord = virtualInvocationID.yxz;
					if (axis == 0)
						outCoord = virtualInvocationID.xyz;
					outCoord += minOutputPixel;
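
					// Bin the final alpha into this layer's histogram; presumably consumed
					// by a later normalization dispatch to preserve alpha coverage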
					const uint32_t bucketIndex = uint32_t(round(clamp(accum.a, 0, 1) * float(ConstevalParameters::AlphaBinCount - 1)));
					histogramAccessor.atomicAdd(workGroupID.z, bucketIndex, uint32_t(1));

					outImageAccessor.set(outCoord, workGroupID.z, accum);
				}
				else
				{
					uint32_t scratchOffset = writeScratchOffset;
					if (axis == 0)
						scratchOffset += ndarray_addressing::snakeCurve(virtualInvocationID.yxz, uint32_t3(preloadRegion.y, outputTexelsPerWG.x, preloadRegion.z));
					else
						scratchOffset += ndarray_addressing::snakeCurve(virtualInvocationID.zxy, uint32_t3(preloadRegion.z, outputTexelsPerWG.y, outputTexelsPerWG.x));

					for (uint32_t ch = 0; ch < ConstevalParameters::BlitOutChannelCount; ++ch)
						sharedAccessor.set(ch * ConstevalParameters::SMemFloatsPerChannel + scratchOffset, accum[ch]);
				}
			}
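
			// Ping-pong: this pass's write half of scratch becomes the next pass's read half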
			const uint32_t tmp = readScratchOffset;
			readScratchOffset = writeScratchOffset;
			writeScratchOffset = tmp;
			GroupMemoryBarrierWithGroupSync();
		}
	}
};

}
}
}

#endif