@@ -70,10 +70,11 @@ struct ArithmeticConfiguration
70
70
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1 ,uint16_t,
71
71
0 ,
72
72
conditional_value<LevelCount==3 ,uint16_t,
73
- LevelInputCount_2,
73
+ LevelInputCount_2+(SubgroupSize*ItemsPerInvocation_1)- 1 ,
74
74
0
75
75
>::value + LevelInputCount_1
76
76
>::value;
77
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t __padding = conditional_value<LevelCount==3 ,uint16_t,SubgroupSize-1 ,0 >::value;
77
78
78
79
static bool electLast ()
79
80
{
@@ -90,40 +91,42 @@ struct ArithmeticConfiguration
90
91
// get a coalesced index to store for the next level in shared mem, e.g. level 0 -> level 1
91
92
// specify the next level to store values for in template param
92
93
// at level==LevelCount-1, it is guaranteed to have SubgroupSize elements
93
- template<uint16_t level NBL_FUNC_REQUIRES (level>0 && level<LevelCount)
94
+ template<uint16_t level> // NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
94
95
static uint16_t sharedStoreIndex (const uint16_t virtualSubgroupID)
95
96
{
96
- uint16_t offsetBySubgroup ;
97
+ uint16_t nextLevelInvocationCount ;
97
98
if (level == LevelCount-1 )
98
- offsetBySubgroup = SubgroupSize;
99
+ nextLevelInvocationCount = SubgroupSize;
99
100
else
100
- offsetBySubgroup = __SubgroupsPerVirtualWorkgroup;
101
+ nextLevelInvocationCount = __SubgroupsPerVirtualWorkgroup;
101
102
102
103
if (level==2 )
103
- return LevelInputCount_1 + (virtualSubgroupID & (ItemsPerInvocation_2-uint16_t (1u))) * offsetBySubgroup + (virtualSubgroupID/ItemsPerInvocation_2);
104
+ return LevelInputCount_1 + ((SubgroupSize- uint16_t (1u))*ItemsPerInvocation_1) + ( virtualSubgroupID & (ItemsPerInvocation_2-uint16_t (1u))) * nextLevelInvocationCount + (virtualSubgroupID/ItemsPerInvocation_2);
104
105
else
105
- return (virtualSubgroupID & (ItemsPerInvocation_1-uint16_t (1u))) * offsetBySubgroup + (virtualSubgroupID/ItemsPerInvocation_1);
106
+ return (virtualSubgroupID & (ItemsPerInvocation_1-uint16_t (1u))) * (nextLevelInvocationCount+__padding) + (virtualSubgroupID/ItemsPerInvocation_1) + virtualSubgroupID/(SubgroupSize* ItemsPerInvocation_1);
106
107
}
107
108
108
- template<uint16_t level NBL_FUNC_REQUIRES (level>0 && level<LevelCount)
109
+ template<uint16_t level> // NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
109
110
static uint16_t sharedStoreIndexFromVirtualIndex (const uint16_t subgroupID, const uint16_t workgroupInVirtualIndex)
110
111
{
111
112
const uint16_t virtualID = virtualSubgroupID (subgroupID, workgroupInVirtualIndex);
112
113
return sharedStoreIndex<level>(virtualID);
113
114
}
114
115
115
116
// get the coalesced index in shared mem at the current level
116
- template<uint16_t level NBL_FUNC_REQUIRES (level>0 && level<LevelCount)
117
+ template<uint16_t level> // NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
117
118
static uint16_t sharedLoadIndex (const uint16_t invocationIndex, const uint16_t component)
118
119
{
119
- uint16_t smem_offset = 0u;
120
- if (level == 2 )
121
- smem_offset += LevelInputCount_1;
122
-
120
+ uint16_t levelInvocationCount;
123
121
if (level == LevelCount-1 )
124
- return component * SubgroupSize + invocationIndex + smem_offset;
122
+ levelInvocationCount = SubgroupSize;
123
+ else
124
+ levelInvocationCount = __SubgroupsPerVirtualWorkgroup;
125
+
126
+ if (level==2 )
127
+ return LevelInputCount_1 + ((SubgroupSize-uint16_t (1u))*ItemsPerInvocation_1) + component * levelInvocationCount + invocationIndex + invocationIndex/SubgroupSize;
125
128
else
126
- return component * __SubgroupsPerVirtualWorkgroup + invocationIndex;
129
+ return component * (levelInvocationCount+__padding) + invocationIndex + invocationIndex/SubgroupSize ;
127
130
}
128
131
};
129
132
0 commit comments