Skip to content

Commit 472aa0b

Browse files
committed
more fixes to indexing
1 parent b062ede commit 472aa0b

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ struct ArithmeticConfiguration
4646
using virtual_wg_t = impl::virtual_wg_size_log2<WorkgroupSizeLog2, SubgroupSizeLog2>;
4747
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = virtual_wg_t::levels;
4848
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << virtual_wg_t::value;
49+
static_assert(VirtualWorkgropupSize<=WorkgroupSize*SubgroupSize)
50+
51+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
52+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = uint16_t(0x1u) << __SubgroupsPerVirtualWorkgroupLog2;
53+
4954
using items_per_invoc_t = impl::items_per_invocation<virtual_wg_t, _ItemsPerInvocation, WorkgroupSizeLog2, SubgroupSizeLog2>;
5055
// NBL_CONSTEXPR_STATIC_INLINE uint32_t2 ItemsPerInvocation; TODO? doesn't allow inline definitions for uint32_t2 for some reason, uint32_t[2] as well ; declaring out of line results in not constant expression
5156
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = items_per_invoc_t::value0;
@@ -74,10 +79,16 @@ struct ArithmeticConfiguration
7479
template<uint16_t level>
7580
static uint32_t sharedStoreIndex(const uint32_t subgroupID)
7681
{
82+
uint32_t offsetBySubgroup;
83+
if (level == LevelCount-1)
84+
offsetBySubgroup = SubgroupSize;
85+
else
86+
offsetBySubgroup = __SubgroupsPerVirtualWorkgroup;
87+
7788
if (level<2)
78-
return (subgroupID & (ItemsPerInvocation_1-1)) * SubgroupSize + (subgroupID/ItemsPerInvocation_1);
89+
return (subgroupID & (ItemsPerInvocation_1-1)) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_1);
7990
else
80-
return (subgroupID & (ItemsPerInvocation_2-1)) * SubgroupSize + (subgroupID/ItemsPerInvocation_2);
91+
return (subgroupID & (ItemsPerInvocation_2-1)) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_2);
8192
}
8293

8394
template<uint16_t level>

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
321321
// level 1 scan
322322
const uint32_t lv1_smem_size = Config::SubgroupsSize*Config::ItemsPerInvocation_1;
323323
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
324-
if (glsl::gl_SubgroupID() < lv1_smem_size)
324+
if (glsl::gl_SubgroupID() < Config::SubgroupsSize*Config::ItemsPerInvocation_2)
325325
{
326326
vector_lv1_t lv1_val;
327327
const uint32_t prevIndex = invocationIndex-1;

0 commit comments

Comments
 (0)