Skip to content

Commit 7d77d30

Browse files
committed
change indexing to uint16_t
1 parent f82b405 commit 7d77d30

File tree

2 files changed

+84
-81
lines changed

2 files changed

+84
-81
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,19 @@ struct ArithmeticConfiguration
5858
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = items_per_invoc_t::value1;
5959
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
6060

61-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
62-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroup = uint16_t(0x1u) << __ItemsPerVirtualWorkgroupLog2;
63-
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = __ItemsPerVirtualWorkgroup / ItemsPerInvocation_1;
61+
// NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
62+
// NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroup = uint16_t(0x1u) << __ItemsPerVirtualWorkgroupLog2;
63+
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_1 = conditional_value<LevelCount==3,uint16_t,
64+
mpl::max_v<uint16_t, (VirtualWorkgroupSize>>SubgroupSizeLog2), SubgroupSize>,
65+
SubgroupSize*ItemsPerInvocation_1>::value;
66+
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_2 = conditional_value<LevelCount==3,uint16_t,SubgroupSize*ItemsPerInvocation_2,0>::value;
67+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = LevelInputCount_1 / ItemsPerInvocation_1;
6468

6569
// user specified the shared mem size of uint32_ts
6670
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1,uint16_t,
6771
0,
6872
conditional_value<LevelCount==3,uint16_t,
69-
SubgroupSize*ItemsPerInvocation_2+__ItemsPerVirtualWorkgroup,
73+
SubgroupSize*ItemsPerInvocation_2+LevelInputCount_1,
7074
SubgroupSize*ItemsPerInvocation_1
7175
>::value
7276
>::value;
@@ -78,7 +82,7 @@ struct ArithmeticConfiguration
7882

7983
// gets a subgroupID as if each workgroup has (VirtualWorkgroupSize/SubgroupSize) subgroups
8084
// each subgroup does work (VirtualWorkgroupSize/WorkgroupSize) times, the index denoted by workgroupInVirtualIndex
81-
static uint32_t virtualSubgroupID(const uint32_t subgroupID, const uint32_t workgroupInVirtualIndex)
85+
static uint16_t virtualSubgroupID(const uint16_t subgroupID, const uint16_t workgroupInVirtualIndex)
8286
{
8387
return workgroupInVirtualIndex * (WorkgroupSize >> SubgroupSizeLog2) + subgroupID;
8488
}
@@ -87,30 +91,30 @@ struct ArithmeticConfiguration
8791
// specify the next level to store values for in template param
8892
// at level==LevelCount-1, it is guaranteed to have SubgroupSize elements
8993
template<uint16_t level>
90-
static uint32_t sharedStoreIndex(const uint32_t subgroupID)
94+
static uint16_t sharedStoreIndex(const uint16_t subgroupID)
9195
{
92-
uint32_t offsetBySubgroup;
96+
uint16_t offsetBySubgroup;
9397
if (level == LevelCount-1)
9498
offsetBySubgroup = SubgroupSize;
9599
else
96100
offsetBySubgroup = __SubgroupsPerVirtualWorkgroup;
97101

98102
if (level<2)
99-
return (subgroupID & (ItemsPerInvocation_1-1)) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_1);
103+
return (subgroupID & (ItemsPerInvocation_1-uint16_t(1u))) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_1);
100104
else
101-
return (subgroupID & (ItemsPerInvocation_2-1)) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_2);
105+
return (subgroupID & (ItemsPerInvocation_2-uint16_t(1u))) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_2);
102106
}
103107

104108
template<uint16_t level>
105-
static uint32_t sharedStoreIndexFromVirtualIndex(const uint32_t subgroupID, const uint32_t workgroupInVirtualIndex)
109+
static uint16_t sharedStoreIndexFromVirtualIndex(const uint16_t subgroupID, const uint16_t workgroupInVirtualIndex)
106110
{
107-
const uint32_t virtualID = virtualSubgroupID(subgroupID, workgroupInVirtualIndex);
111+
const uint16_t virtualID = virtualSubgroupID(subgroupID, workgroupInVirtualIndex);
108112
return sharedStoreIndex<level>(virtualID);
109113
}
110114

111115
// get the coalesced index in shared mem at the current level
112116
template<uint16_t level>
113-
static uint32_t sharedLoadIndex(const uint32_t invocationIndex, const uint32_t component)
117+
static uint16_t sharedLoadIndex(const uint16_t invocationIndex, const uint16_t component)
114118
{
115119
if (level == LevelCount-1)
116120
return component * SubgroupSize + invocationIndex;

0 commit comments

Comments
 (0)