@@ -36,6 +36,8 @@ struct items_per_invocation
36
36
// value0 is BaseItemsPerInvocation, passed through unchanged.
NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation;
37
37
// value1 = 1 << k, where k is ItemsPerInvocationProductLog2, capped at 2 when VirtualWorkgroup::levels==3.
// NOTE(review): assumes conditional_value<cond,T,a,b> yields `a` when cond is true - confirm against its definition.
NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t (0x1u) << conditional_value<VirtualWorkgroup::levels==3 , uint16_t,mpl::min_v<uint16_t,ItemsPerInvocationProductLog2,2 >, ItemsPerInvocationProductLog2>::value;
38
38
// value2 = 1 << max(ItemsPerInvocationProductLog2-2, 0): the part of the exponent above 2, never negative.
// With the cap on value1 applied, value1*value2 == 1 << ItemsPerInvocationProductLog2.
NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t (0x1u) << mpl::max_v<int16_t,ItemsPerInvocationProductLog2-2 ,0 >;
39
+
40
// '+' line: the three values bundled as a tuple of integral_constant<uint16_t,...> types.
+ using ItemsPerInvocation = tuple<integral_constant<uint16_t,value0>,integral_constant<uint16_t,value1>,integral_constant<uint16_t,value2> >;
39
41
};
40
42
}
41
43
@@ -53,26 +55,24 @@ struct ArithmeticConfiguration
53
55
// Compile-time check: VirtualWorkgroupSize must not exceed WorkgroupSize*SubgroupSize.
static_assert (VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
54
56
55
57
// The items-per-invocation values are computed by impl::items_per_invocation for this virtual workgroup type.
using items_per_invoc_t = impl::items_per_invocation<virtual_wg_t, _ItemsPerInvocation>;
56
// '-' line: the tuple alias that used to be declared here is removed by this diff.
- using ItemsPerInvocation = tuple<integral_constant<uint16_t,items_per_invoc_t::value0>,integral_constant<uint16_t,items_per_invoc_t::value1>,integral_constant<uint16_t,items_per_invoc_t::value2> >;
57
58
// Named copies of items_per_invoc_t::value0 / value1 / value2.
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = items_per_invoc_t::value0;
58
59
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = items_per_invoc_t::value1;
59
60
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
61
// '+' line: rejects configurations where ItemsPerInvocation_2 exceeds 4; per the message, those would need a 4 level scan.
+ static_assert (ItemsPerInvocation_2<=4 , "4 level scan would have been needed with this config!" );
60
62
61
- // NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
62
- // NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroup = uint16_t(0x1u) << __ItemsPerVirtualWorkgroupLog2;
63
63
// LevelInputCount_1: when LevelCount==3 it is max(VirtualWorkgroupSize>>SubgroupSizeLog2, SubgroupSize),
// otherwise SubgroupSize*ItemsPerInvocation_1.
// NOTE(review): assumes conditional_value<cond,T,a,b> yields `a` when cond is true - confirm against its definition.
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_1 = conditional_value<LevelCount==3 ,uint16_t,
64
64
mpl::max_v<uint16_t, (VirtualWorkgroupSize>>SubgroupSizeLog2), SubgroupSize>,
65
65
SubgroupSize*ItemsPerInvocation_1>::value;
66
66
// LevelInputCount_2: SubgroupSize*ItemsPerInvocation_2 when LevelCount==3, otherwise 0 (same assumption as above).
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_2 = conditional_value<LevelCount==3 ,uint16_t,SubgroupSize*ItemsPerInvocation_2,0 >::value;
67
67
// LevelInputCount_1 divided by ItemsPerInvocation_1 (integer division).
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = LevelInputCount_1 / ItemsPerInvocation_1;
68
68
69
- // user specified the shared mem size of uint32_ts
69
+ // user specified the shared mem size of Scalars
70
70
// Element count of the shared scratch area, selected at compile time on LevelCount.
// The outer selector has 0 as its LevelCount==1 value. In the '+' revision the other value is
// LevelInputCount_1 plus an inner selector of LevelInputCount_2 (LevelCount==3) or 0.
// The '-' revision's inner selector was SubgroupSize*ItemsPerInvocation_2+LevelInputCount_1 (LevelCount==3)
// or SubgroupSize*ItemsPerInvocation_1.
// NOTE(review): assumes conditional_value<cond,T,a,b> yields `a` when cond is true - confirm against its definition.
// NOTE(review): declared uint32_t but selected through conditional_value<..., uint16_t, ...> - confirm the sum always fits in 16 bits.
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1 ,uint16_t,
71
71
0 ,
72
72
conditional_value<LevelCount==3 ,uint16_t,
73
- SubgroupSize*ItemsPerInvocation_2+LevelInputCount_1 ,
74
- SubgroupSize*ItemsPerInvocation_1
75
- >::value
73
+ LevelInputCount_2 ,
74
+ 0
75
+ >::value + LevelInputCount_1
76
76
>::value;
77
77
78
78
static bool electLast ()
@@ -90,30 +90,30 @@ struct ArithmeticConfiguration
90
90
// get a coalesced index to store for the next level in shared mem, e.g. level 0 -> level 1
91
91
// specify the next level to store values for in template param
92
92
// at level==LevelCount-1, it is guaranteed to have SubgroupSize elements
93
// Computes the shared-memory index at which a subgroup's value is stored as input for `level`.
// The ID is split by ItemsPerInvocation_N (N = 2 for level 2, N = 1 otherwise):
//   (ID & (ItemsPerInvocation_N-1)) * offsetBySubgroup + ID / ItemsPerInvocation_N
// The mask equals ID % ItemsPerInvocation_N only when ItemsPerInvocation_N is a power of two.
// Differences between the revisions shown in this diff:
//   '-' revision: unconstrained `level`, parameter named subgroupID, no base offset for any level.
//   '+' revision: `level` constrained via NBL_FUNC_REQUIRES to 0 < level < LevelCount, parameter renamed to
//                 virtualSubgroupID, and level 2 indices are offset by LevelInputCount_1.
- template<uint16_t level>
94
- static uint16_t sharedStoreIndex (const uint16_t subgroupID )
93
+ template<uint16_t level NBL_FUNC_REQUIRES (level> 0 && level<LevelCount)
94
+ static uint16_t sharedStoreIndex (const uint16_t virtualSubgroupID )
95
95
{
96
96
uint16_t offsetBySubgroup;
97
97
// Stride multiplied by the masked part of the ID: SubgroupSize for the last level, __SubgroupsPerVirtualWorkgroup otherwise.
if (level == LevelCount-1 )
98
98
offsetBySubgroup = SubgroupSize;
99
99
else
100
100
offsetBySubgroup = __SubgroupsPerVirtualWorkgroup;
101
101
102
// '-' revision: levels below 2 split the ID by ItemsPerInvocation_1.
- if (level< 2 )
103
- return (subgroupID & (ItemsPerInvocation_1 -uint16_t (1u))) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_1 );
102
// '+' revision: level 2 splits the ID by ItemsPerInvocation_2 and adds LevelInputCount_1 to the result.
+ if (level== 2 )
103
+ return LevelInputCount_1 + (virtualSubgroupID & (ItemsPerInvocation_2 -uint16_t (1u))) * offsetBySubgroup + (virtualSubgroupID/ItemsPerInvocation_2 );
104
104
else
105
- return (subgroupID & (ItemsPerInvocation_2 -uint16_t (1u))) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_2 );
105
+ return (virtualSubgroupID & (ItemsPerInvocation_1 -uint16_t (1u))) * offsetBySubgroup + (virtualSubgroupID/ItemsPerInvocation_1 );
106
106
}
107
107
108
// Convenience wrapper around sharedStoreIndex<level>: derives an ID from (subgroupID, workgroupInVirtualIndex)
// by calling virtualSubgroupID(), which is defined outside this excerpt, and forwards it.
// The '+' revision constrains `level` via NBL_FUNC_REQUIRES to 0 < level < LevelCount; the '-' revision was unconstrained.
- template<uint16_t level>
108
+ template<uint16_t level NBL_FUNC_REQUIRES (level> 0 && level<LevelCount)
109
109
static uint16_t sharedStoreIndexFromVirtualIndex (const uint16_t subgroupID, const uint16_t workgroupInVirtualIndex)
110
110
{
111
111
const uint16_t virtualID = virtualSubgroupID (subgroupID, workgroupInVirtualIndex);
112
112
return sharedStoreIndex<level>(virtualID);
113
113
}
114
114
115
115
// get the coalesced index in shared mem at the current level
116
- template<uint16_t level>
116
+ template<uint16_t level NBL_FUNC_REQUIRES (level> 0 && level<LevelCount)
117
117
static uint16_t sharedLoadIndex (const uint16_t invocationIndex, const uint16_t component)
118
118
{
119
119
if (level == LevelCount-1 )
0 commit comments