Skip to content

Commit a639145

Browse files
committed
fixes to 3-level scan and minor stuff
1 parent 49ca655 commit a639145

File tree

2 files changed

+42
-27
lines changed

2 files changed

+42
-27
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ struct ArithmeticConfiguration
6161
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;
6262

6363
// must have at least enough level 0 outputs to feed a single subgroup
64-
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
65-
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroup = uint16_t(0x1u) << SubgroupsPerVirtualWorkgroupLog2;
64+
NBL_CONSTEXPR_STATIC_INLINE uint16_t _SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
65+
NBL_CONSTEXPR_STATIC_INLINE uint16_t _SubgroupsPerVirtualWorkgroup = uint16_t(0x1u) << _SubgroupsPerVirtualWorkgroupLog2;
6666

6767
using virtual_wg_t = impl::virtual_wg_size_log2<WorkgroupSizeLog2, SubgroupSizeLog2>;
6868
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels;
@@ -83,7 +83,7 @@ struct ArithmeticConfiguration
8383

8484
static uint32_t sharedMemCoalescedIndex(const uint32_t id, const uint32_t itemsPerInvocation)
8585
{
86-
return (id & (itemsPerInvocation-1)) * SubgroupsPerVirtualWorkgroup + (id/itemsPerInvocation);
86+
return (id & (itemsPerInvocation-1)) * SubgroupSize + (id/itemsPerInvocation);
8787
}
8888
};
8989

@@ -96,7 +96,6 @@ struct is_configuration<ArithmeticConfiguration<W,S,I> > : bool_constant<true> {
9696
template<typename T>
9797
NBL_CONSTEXPR bool is_configuration_v = is_configuration<T>::value;
9898

99-
10099
}
101100
}
102101
}

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
120120
vector_lv1_t lv1_val;
121121
[unroll]
122122
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
123-
scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
123+
scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
124124
lv1_val = reduction1(lv1_val);
125125

126126
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
@@ -176,12 +176,12 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
176176
const uint32_t prevIndex = invocationIndex-1;
177177
[unroll]
178178
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
179-
scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+prevIndex,lv1_val[i]);
179+
scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+prevIndex,lv1_val[i]);
180180
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
181181
lv1_val = inclusiveScan1(lv1_val);
182182
[unroll]
183183
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
184-
scratchAccessor.template set<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
184+
scratchAccessor.template set<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
185185
}
186186
scratchAccessor.workgroupExecutionAndMemoryBarrier();
187187

@@ -258,7 +258,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
258258
vector_lv1_t lv1_val;
259259
[unroll]
260260
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
261-
scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
261+
scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
262262
lv1_val = reduction1(lv1_val);
263263
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
264264
{
@@ -275,7 +275,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
275275
vector_lv2_t lv2_val;
276276
[unroll]
277277
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
278-
scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv2_val[i]);
278+
scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv2_val[i]);
279279
lv2_val = reduction2(lv2_val);
280280
scratchAccessor.template set<scalar_t>(invocationIndex, lv2_val[Config::ItemsPerInvocation_2-1]);
281281
}
@@ -324,15 +324,20 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
324324
scratchAccessor.workgroupExecutionAndMemoryBarrier();
325325

326326
// level 1 scan
327-
const uint32_t lv1_smem_size = Config::SubgroupsPerVirtualWorkgroup*Config::ItemsPerInvocation_1;
327+
const uint32_t lv1_smem_size = Config::SubgroupSize*Config::ItemsPerInvocation_1;
328328
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
329329
if (glsl::gl_SubgroupID() < lv1_smem_size)
330330
{
331331
vector_lv1_t lv1_val;
332+
const uint32_t prevIndex = invocationIndex-1;
332333
[unroll]
333334
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
334-
scratchAccessor.template get<scalar_t>(i*Config::SubgroupsPerVirtualWorkgroup+invocationIndex,lv1_val[i]);
335+
scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+prevIndex,lv1_val[i]);
336+
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
335337
lv1_val = inclusiveScan1(lv1_val);
338+
[unroll]
339+
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
340+
scratchAccessor.template set<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
336341
if (glsl::gl_SubgroupInvocationID()==Config::SubgroupSize-1)
337342
{
338343
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
@@ -351,37 +356,48 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
351356
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
352357
scratchAccessor.template get<scalar_t>(lv1_smem_size+i*Config::SubgroupSize+prevIndex,lv2_val[i]);
353358
lv2_val[0] = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val[0], bool(invocationIndex));
354-
vector_lv2_t shiftedScan = inclusiveScan2(lv2_val);
355-
356-
// combine with level 1, only last element of each
359+
lv2_val = inclusiveScan2(lv2_val);
357360
[unroll]
358-
for (uint32_t i = 0; i < Config::SubgroupsPerVirtualWorkgroup; i++)
359-
{
360-
scalar_t last_val;
361-
scratchAccessor.template get<scalar_t>((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i),last_val);
362-
scalar_t val = hlsl::mix(hlsl::promote<vector_lv2_t>(BinOp::identity), lv2_val, bool(i));
363-
val = binop(last_val, shiftedScan[Config::ItemsPerInvocation_2-1]);
364-
scratchAccessor.template set<scalar_t>((Config::ItemsPerInvocation_1-1)*Config::SubgroupsPerVirtualWorkgroup+(Config::SubgroupsPerVirtualWorkgroup-1-i), last_val);
365-
}
361+
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
362+
scratchAccessor.template set<scalar_t>(lv1_smem_size+i*Config::SubgroupSize+invocationIndex,lv2_val[i]);
366363
}
367364
scratchAccessor.workgroupExecutionAndMemoryBarrier();
368365

366+
// combine with level 1
367+
if (glsl::gl_SubgroupID() < lv1_smem_size)
368+
{
369+
vector_lv1_t lv1_val;
370+
[unroll]
371+
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
372+
scratchAccessor.template get<scalar_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
373+
374+
scalar_t lv2_scan;
375+
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(glsl::gl_SubgroupID(), Config::ItemsPerInvocation_2); // (glsl::gl_SubgroupID() & (Config::ItemsPerInvocation_2-1)) * Config::SubgroupSize + (glsl::gl_SubgroupID()/Config::ItemsPerInvocation_2);
376+
scratchAccessor.template get<scalar_t>(lv1_smem_size+bankedIndex, lv2_scan);
377+
378+
[unroll]
379+
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
380+
scratchAccessor.template set<scalar_t>(i*Config::SubgroupSize+invocationIndex, binop(lv1_val[i],lv2_scan));
381+
}
382+
369383
// combine with level 0
370384
[unroll]
371385
for (uint32_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
372386
{
373387
vector_lv0_t value;
374388
dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
375389

376-
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx); // idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID();
377-
const scalar_t left;
378-
scratchAccessor.template get<scalar_t>(virtualSubgroupID, left);
390+
const uint32_t virtualSubgroupID = Config::virtualSubgroupID(glsl::gl_SubgroupID(), idx);
391+
const uint32_t bankedIndex = Config::sharedMemCoalescedIndex(virtualSubgroupID, Config::ItemsPerInvocation_1);
392+
scalar_t left;
393+
scratchAccessor.template get<scalar_t>(bankedIndex,left);
379394
if (Exclusive)
380395
{
381396
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
382397
[unroll]
383-
for (uint32_t i = 0; i < Config::ItemsPerInvocation_0; i++)
384-
value[Config::ItemsPerInvocation_0-i-1] = binop(left, hlsl::mix(value[Config::ItemsPerInvocation_0-i-2], left_last_elem, (Config::ItemsPerInvocation_0-i-1==0)));
398+
for (uint32_t i = Config::ItemsPerInvocation_0-1; i > 0; i--)
399+
value[i] = binop(left, value[i-1]);
400+
value[0] = binop(left, left_last_elem);
385401
}
386402
else
387403
{

0 commit comments

Comments
 (0)