@@ -58,15 +58,19 @@ struct ArithmeticConfiguration
58
58
// Items-per-invocation counts, taken from items_per_invoc_t::value1 and items_per_invoc_t::value2.
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = items_per_invoc_t::value1;
59
59
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
60
60
61
- NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
62
- NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroup = uint16_t (0x1u) << __ItemsPerVirtualWorkgroupLog2;
63
- NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = __ItemsPerVirtualWorkgroup / ItemsPerInvocation_1;
61
+ // NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
62
+ // NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroup = uint16_t(0x1u) << __ItemsPerVirtualWorkgroupLog2;
63
// LevelInputCount_1: for LevelCount==3 it is max(VirtualWorkgroupSize>>SubgroupSizeLog2, SubgroupSize),
// otherwise SubgroupSize*ItemsPerInvocation_1.
// (assumes conditional_value<cond,T,a,b>::value yields `a` when `cond` holds, else `b` -- TODO confirm against its definition)
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_1 = conditional_value<LevelCount==3 ,uint16_t,
64
+ mpl::max_v<uint16_t, (VirtualWorkgroupSize>>SubgroupSizeLog2), SubgroupSize>,
65
+ SubgroupSize*ItemsPerInvocation_1>::value;
66
// LevelInputCount_2: SubgroupSize*ItemsPerInvocation_2 for LevelCount==3, otherwise 0 (same conditional_value assumption).
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_2 = conditional_value<LevelCount==3 ,uint16_t,SubgroupSize*ItemsPerInvocation_2,0 >::value;
67
// Derived from LevelInputCount_1 here; the removed definition divided __ItemsPerVirtualWorkgroup instead.
+ NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = LevelInputCount_1 / ItemsPerInvocation_1;
64
68
65
69
// user specified the shared mem size of uint32_ts
66
70
// Selected value: 0 for LevelCount==1; SubgroupSize*ItemsPerInvocation_2 plus the first-level input count for LevelCount==3;
// SubgroupSize*ItemsPerInvocation_1 otherwise.
// NOTE(review): the nested conditional_value computes in uint16_t while the constant is declared uint32_t -- confirm the sum cannot exceed 16 bits.
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1 ,uint16_t,
67
71
0 ,
68
72
conditional_value<LevelCount==3 ,uint16_t,
69
- SubgroupSize*ItemsPerInvocation_2+__ItemsPerVirtualWorkgroup ,
73
+ SubgroupSize*ItemsPerInvocation_2+LevelInputCount_1 ,
70
74
SubgroupSize*ItemsPerInvocation_1
71
75
>::value
72
76
>::value;
@@ -78,7 +82,7 @@ struct ArithmeticConfiguration
78
82
79
83
// gets a subgroupID as if each workgroup has (VirtualWorkgroupSize/SubgroupSize) subgroups
80
84
// each subgroup does work (VirtualWorkgroupSize/WorkgroupSize) times, the index denoted by workgroupInVirtualIndex
81
// Returns workgroupInVirtualIndex * (WorkgroupSize >> SubgroupSizeLog2) + subgroupID.
// The added signature narrows both parameters and the return type from uint32_t to uint16_t.
// NOTE(review): confirm the result always fits in 16 bits for the supported workgroup sizes.
- static uint32_t virtualSubgroupID (const uint32_t subgroupID, const uint32_t workgroupInVirtualIndex)
85
+ static uint16_t virtualSubgroupID (const uint16_t subgroupID, const uint16_t workgroupInVirtualIndex)
82
86
{
83
87
return workgroupInVirtualIndex * (WorkgroupSize >> SubgroupSizeLog2) + subgroupID;
84
88
}
@@ -87,30 +91,30 @@ struct ArithmeticConfiguration
87
91
// specify the next level to store values for in template param
88
92
// at level==LevelCount-1, it is guaranteed to have SubgroupSize elements
89
93
// Maps a subgroup ID to an index for the level given by the template parameter.
// Result: (subgroupID & (N-1)) * offsetBySubgroup + subgroupID/N, where N is ItemsPerInvocation_1
// for level<2 and ItemsPerInvocation_2 otherwise.
// NOTE(review): `& (N-1)` equals `% N` only when N is a power of two -- presumably guaranteed by items_per_invoc_t; confirm.
template<uint16_t level>
90
- static uint32_t sharedStoreIndex (const uint32_t subgroupID)
94
+ static uint16_t sharedStoreIndex (const uint16_t subgroupID)
91
95
{
92
- uint32_t offsetBySubgroup;
96
+ uint16_t offsetBySubgroup;
93
97
// stride between components: SubgroupSize at the last level, __SubgroupsPerVirtualWorkgroup at earlier levels
if (level == LevelCount-1 )
94
98
offsetBySubgroup = SubgroupSize;
95
99
else
96
100
offsetBySubgroup = __SubgroupsPerVirtualWorkgroup;
97
101
98
102
// levels 0 and 1 use ItemsPerInvocation_1, higher levels use ItemsPerInvocation_2
if (level<2 )
99
- return (subgroupID & (ItemsPerInvocation_1-1 )) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_1);
103
+ return (subgroupID & (ItemsPerInvocation_1-uint16_t (1u) )) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_1);
100
104
else
101
- return (subgroupID & (ItemsPerInvocation_2-1 )) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_2);
105
+ return (subgroupID & (ItemsPerInvocation_2-uint16_t (1u) )) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_2);
102
106
}
103
107
104
108
// Convenience wrapper: converts (subgroupID, workgroupInVirtualIndex) to a virtual subgroup ID
// with virtualSubgroupID, then forwards that ID to sharedStoreIndex<level>.
// The added lines narrow the parameters, the local and the return type from uint32_t to uint16_t.
template<uint16_t level>
105
- static uint32_t sharedStoreIndexFromVirtualIndex (const uint32_t subgroupID, const uint32_t workgroupInVirtualIndex)
109
+ static uint16_t sharedStoreIndexFromVirtualIndex (const uint16_t subgroupID, const uint16_t workgroupInVirtualIndex)
106
110
{
107
- const uint32_t virtualID = virtualSubgroupID (subgroupID, workgroupInVirtualIndex);
111
+ const uint16_t virtualID = virtualSubgroupID (subgroupID, workgroupInVirtualIndex);
108
112
return sharedStoreIndex<level>(virtualID);
109
113
}
110
114
111
115
// get the coalesced index in shared mem at the current level
112
116
template<uint16_t level>
113
- static uint32_t sharedLoadIndex (const uint32_t invocationIndex, const uint32_t component)
117
+ static uint16_t sharedLoadIndex (const uint16_t invocationIndex, const uint16_t component)
114
118
{
115
119
if (level == LevelCount-1 )
116
120
return component * SubgroupSize + invocationIndex;
0 commit comments