Skip to content

Commit c483941

Browse files
committed
share level 0 scan between 2-level and 3-level scans (and reduce)
1 parent 472aa0b commit c483941

File tree

2 files changed

+40
-55
lines changed

2 files changed

+40
-55
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ struct ArithmeticConfiguration
4646
using virtual_wg_t = impl::virtual_wg_size_log2<WorkgroupSizeLog2, SubgroupSizeLog2>;
4747
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels;
4848
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << virtual_wg_t::value;
49-
static_assert(VirtualWorkgropupSize<=WorkgroupSize*SubgroupSize)
49+
static_assert(VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
5050

5151
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
5252
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = uint16_t(0x1u) << __SubgroupsPerVirtualWorkgroupLog2;

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 39 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -85,22 +85,17 @@ struct reduce<Config, BinOp, 2, device_capabilities>
8585
using vector_lv0_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
8686
using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>;
8787

88-
template<class DataAccessor, class ScratchAccessor>
89-
scalar_t __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
88+
template<class DataAccessor, class ScratchAccessor, class Params, typename vector_t>
89+
static void __doLevel0(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
9090
{
91-
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
92-
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
93-
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
94-
BinOp binop;
95-
9691
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
9792
// level 0 scan
98-
subgroup2::reduction<params_lv0_t> reduction0;
93+
subgroup2::reduction<Params> reduction0;
9994
[unroll]
10095
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
10196
{
102-
vector_lv0_t scan_local;
103-
dataAccessor.template get<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local);
97+
vector_t scan_local;
98+
dataAccessor.template get<vector_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local);
10499
scan_local = reduction0(scan_local);
105100
if (Config::electLast())
106101
{
@@ -109,7 +104,19 @@ struct reduce<Config, BinOp, 2, device_capabilities>
109104
}
110105
}
111106
scratchAccessor.workgroupExecutionAndMemoryBarrier();
107+
}
108+
109+
template<class DataAccessor, class ScratchAccessor>
110+
scalar_t __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
111+
{
112+
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
113+
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
114+
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
115+
BinOp binop;
112116

117+
__doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
118+
119+
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
113120
// level 1 scan
114121
subgroup2::reduction<params_lv1_t> reduction1;
115122
if (glsl::gl_SubgroupID() == 0)
@@ -138,32 +145,39 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
138145
using vector_lv0_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
139146
using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>;
140147

141-
template<class DataAccessor, class ScratchAccessor>
142-
void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
148+
template<class DataAccessor, class ScratchAccessor, class Params, typename vector_t>
149+
static void __doLevel0(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
143150
{
144-
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
145-
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
146-
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
147-
BinOp binop;
148-
149151
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
150-
subgroup2::inclusive_scan<params_lv0_t> inclusiveScan0;
152+
subgroup2::inclusive_scan<Params> inclusiveScan0;
151153
// level 0 scan
152154
[unroll]
153155
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
154156
{
155-
vector_lv0_t value;
156-
dataAccessor.template get<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
157+
vector_t value;
158+
dataAccessor.template get<vector_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
157159
value = inclusiveScan0(value);
158-
dataAccessor.template set<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
160+
dataAccessor.template set<vector_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
159161
if (Config::electLast())
160162
{
161163
const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
162164
scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
163165
}
164166
}
165167
scratchAccessor.workgroupExecutionAndMemoryBarrier();
168+
}
166169

170+
template<class DataAccessor, class ScratchAccessor>
171+
void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
172+
{
173+
using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
174+
using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
175+
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
176+
BinOp binop;
177+
178+
__doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
179+
180+
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
167181
// level 1 scan
168182
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
169183
if (glsl::gl_SubgroupID() == 0)
@@ -228,23 +242,9 @@ struct reduce<Config, BinOp, 3, device_capabilities>
228242
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
229243
BinOp binop;
230244

231-
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
232-
// level 0 scan
233-
subgroup2::reduction<params_lv0_t> reduction0;
234-
[unroll]
235-
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
236-
{
237-
vector_lv0_t scan_local;
238-
dataAccessor.template get<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local);
239-
scan_local = reduction0(scan_local);
240-
if (Config::electLast())
241-
{
242-
const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
243-
scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
244-
}
245-
}
246-
scratchAccessor.workgroupExecutionAndMemoryBarrier();
245+
reduce<Config, BinOp, 2, device_capabilities>::template __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
247246

247+
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
248248
// level 1 scan
249249
const uint32_t lv1_smem_size = Config::SubgroupsSize*Config::ItemsPerInvocation_1;
250250
subgroup2::reduction<params_lv1_t> reduction1;
@@ -300,24 +300,9 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
300300
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
301301
BinOp binop;
302302

303-
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
304-
subgroup2::inclusive_scan<params_lv0_t> inclusiveScan0;
305-
// level 0 scan
306-
[unroll]
307-
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
308-
{
309-
vector_lv0_t value;
310-
dataAccessor.template get<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
311-
value = inclusiveScan0(value);
312-
dataAccessor.template set<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
313-
if (Config::electLast())
314-
{
315-
const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
316-
scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
317-
}
318-
}
319-
scratchAccessor.workgroupExecutionAndMemoryBarrier();
303+
scan<Config, BinOp, Exclusive, 2, device_capabilities>::template __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
320304

305+
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
321306
// level 1 scan
322307
const uint32_t lv1_smem_size = Config::SubgroupsSize*Config::ItemsPerInvocation_1;
323308
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;

0 commit comments

Comments
 (0)