Skip to content

Commit 203c03a

Browse files
committed
some indexing fixes for 3-level reduce/scan
1 parent 90d3579 commit 203c03a

File tree

2 files changed: 10 additions and 9 deletions

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -51,22 +51,23 @@ struct ArithmeticConfiguration
5151
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << virtual_wg_t::value;
5252
static_assert(VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
5353

54-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
55-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = uint16_t(0x1u) << __SubgroupsPerVirtualWorkgroupLog2;
56-
5754
using items_per_invoc_t = impl::items_per_invocation<virtual_wg_t, _ItemsPerInvocation>;
5855
// NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? Inline definitions are not allowed for uint32_t2 (nor for uint32_t[2]) for some reason; declaring it out of line results in a non-constant expression
5956
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = items_per_invoc_t::value0;
6057
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = items_per_invoc_t::value1;
6158
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
6259
static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!");
6360

61+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
62+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroup = uint16_t(0x1u) << __ItemsPerVirtualWorkgroupLog2;
63+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = __ItemsPerVirtualWorkgroup / ItemsPerInvocation_1;
64+
6465
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1,uint16_t,
6566
0,
6667
conditional_value<LevelCount==3,uint16_t,
67-
SubgroupSize*ItemsPerInvocation_2,
68-
0
69-
>::value + SubgroupSize*ItemsPerInvocation_1
68+
SubgroupSize*ItemsPerInvocation_2+__ItemsPerVirtualWorkgroup,
69+
SubgroupSize*ItemsPerInvocation_1
70+
>::value
7071
>::value;
7172

7273
static bool electLast()

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -245,7 +245,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
245245

246246
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
247247
// level 1 scan
248-
const uint32_t lv1_smem_size = Config::SubgroupsSize*Config::ItemsPerInvocation_1;
248+
const uint32_t lv1_smem_size = Config::__ItemsPerVirtualWorkgroup;
249249
subgroup2::reduction<params_lv1_t> reduction1;
250250
if (glsl::gl_SubgroupID() < Config::SubgroupSize*Config::ItemsPerInvocation_2)
251251
{
@@ -303,8 +303,8 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
303303

304304
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
305305
// level 1 scan
306-
const uint32_t lv1_smem_size = Config::SubgroupsSize*Config::ItemsPerInvocation_1;
307-
const uint32_t lv1_num_invoc = Config::SubgroupsSize*Config::ItemsPerInvocation_2;
306+
const uint32_t lv1_smem_size = Config::__ItemsPerVirtualWorkgroup;
307+
const uint32_t lv1_num_invoc = Config::SubgroupSize*Config::ItemsPerInvocation_2;
308308
subgroup2::exclusive_scan<params_lv1_t> exclusiveScan1;
309309
if (glsl::gl_SubgroupID() < lv1_num_invoc)
310310
{

0 commit comments

Comments
 (0)