Skip to content

Commit da6c313

Browse files
committed
split out level 0 scans into its own struct
1 parent 37aa99b commit da6c313

File tree

1 file changed

+47
-32
lines changed

1 file changed

+47
-32
lines changed

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,15 @@ struct scan<Config, BinOp, Exclusive, 1, device_capabilities>
7777
}
7878
};
7979

80-
// 2-level scans
80+
// do level 0 scans for 2- and 3-level scans (same code)
8181
template<class Config, class BinOp, class device_capabilities>
82-
struct reduce<Config, BinOp, 2, device_capabilities>
82+
struct reduce_level0
8383
{
8484
using scalar_t = typename BinOp::type_t;
85-
using vector_lv0_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
86-
using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>;
85+
using vector_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
8786

88-
template<class DataAccessor, class ScratchAccessor, class Params, typename vector_t>
89-
static void __doLevel0(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
87+
template<class DataAccessor, class ScratchAccessor, class Params>
88+
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
9089
{
9190
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
9291
// level 0 scan
@@ -104,7 +103,45 @@ struct reduce<Config, BinOp, 2, device_capabilities>
104103
}
105104
}
106105
scratchAccessor.workgroupExecutionAndMemoryBarrier();
106+
};
107+
};
108+
109+
// Level 0 pass shared by the 2-level and 3-level workgroup scan specialisations.
// For each loop iteration an invocation loads one vector_t from dataAccessor, runs
// subgroup2::inclusive_scan<Params> on it, writes the scanned vector back to the same
// index, and, when Config::electLast() is true, stores the last component of the
// scanned vector into scratchAccessor. The pass ends with
// scratchAccessor.workgroupExecutionAndMemoryBarrier().
// NOTE(review): device_capabilities is not referenced inside this struct; presumably it is
// kept so the template signature matches reduce_level0 - confirm.
template<class Config, class BinOp, class device_capabilities>
110+
struct scan_level0
111+
{
112+
using scalar_t = typename BinOp::type_t;
113+
using vector_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
114+
115+
// DataAccessor    : must provide template get<vector_t, uint16_t>(index, out) and set<vector_t, uint16_t>(index, value).
// ScratchAccessor : must provide template set<scalar_t, uint16_t>(index, value) and workgroupExecutionAndMemoryBarrier().
// Params          : template argument forwarded to subgroup2::inclusive_scan.
template<class DataAccessor, class ScratchAccessor, class Params>
116+
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
117+
{
118+
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
119+
subgroup2::inclusive_scan<Params> inclusiveScan0;
120+
// level 0 scan
// The loop runs Config::VirtualWorkgroupSize / Config::WorkgroupSize times.
// virtualInvocationIndex is initialised to invocationIndex and is not modified in the loop body;
// only idx changes the accessed index (idx * Config::WorkgroupSize + virtualInvocationIndex).
121+
[unroll]
122+
for (uint16_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
123+
{
124+
vector_t value;
125+
dataAccessor.template get<vector_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
126+
value = inclusiveScan0(value);
// the scanned vector is written back to the same index it was read from
127+
dataAccessor.template set<vector_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
// NOTE(review): presumably electLast() selects the last invocation of each subgroup, so the
// last component of its inclusive scan is the subgroup total - confirm against Config.
128+
if (Config::electLast())
129+
{
// scratch slot is computed from the subgroup ID and the current iteration idx
130+
const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()), idx);
131+
scratchAccessor.template set<scalar_t, uint16_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
132+
}
133+
}
// barrier issued once, after all iterations have written their scratch values
134+
scratchAccessor.workgroupExecutionAndMemoryBarrier();
107135
}
136+
};
137+
138+
// 2-level scans
139+
template<class Config, class BinOp, class device_capabilities>
140+
struct reduce<Config, BinOp, 2, device_capabilities>
141+
{
142+
using scalar_t = typename BinOp::type_t;
143+
using vector_lv0_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
144+
using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>;
108145

109146
template<class DataAccessor, class ScratchAccessor>
110147
scalar_t __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
@@ -114,7 +151,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
114151
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
115152
BinOp binop;
116153

117-
__doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
154+
reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
118155

119156
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
120157
// level 1 scan
@@ -145,28 +182,6 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
145182
using vector_lv0_t = vector<scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
146183
using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>;
147184

148-
template<class DataAccessor, class ScratchAccessor, class Params, typename vector_t>
149-
static void __doLevel0(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
150-
{
151-
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
152-
subgroup2::inclusive_scan<Params> inclusiveScan0;
153-
// level 0 scan
154-
[unroll]
155-
for (uint16_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
156-
{
157-
vector_t value;
158-
dataAccessor.template get<vector_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
159-
value = inclusiveScan0(value);
160-
dataAccessor.template set<vector_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
161-
if (Config::electLast())
162-
{
163-
const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()), idx);
164-
scratchAccessor.template set<scalar_t, uint16_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]); // set last element of subgroup scan (reduction) to level 1 scan
165-
}
166-
}
167-
scratchAccessor.workgroupExecutionAndMemoryBarrier();
168-
}
169-
170185
template<class DataAccessor, class ScratchAccessor>
171186
void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
172187
{
@@ -175,7 +190,7 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
175190
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
176191
BinOp binop;
177192

178-
__doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
193+
scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
179194

180195
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
181196
// level 1 scan
@@ -243,7 +258,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
243258
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
244259
BinOp binop;
245260

246-
reduce<Config, BinOp, 2, device_capabilities>::template __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
261+
reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
247262

248263
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
249264
// level 1 scan
@@ -300,7 +315,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
300315
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
301316
BinOp binop;
302317

303-
scan<Config, BinOp, Exclusive, 2, device_capabilities>::template __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
318+
scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
304319

305320
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
306321
// level 1 scan

0 commit comments

Comments
 (0)