Skip to content

Commit 37aa99b

Browse files
committed
some adjustments to config and func usages
1 parent 7b15a54 commit 37aa99b

File tree

3 files changed

+23
-25
lines changed

3 files changed

+23
-25
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace hlsl
1717
namespace workgroup2
1818
{
1919

20-
template<class Config, class BinOp, class device_capabilities=void>
20+
template<class Config, class BinOp, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
2121
struct reduction
2222
{
2323
using scalar_t = typename BinOp::type_t;
@@ -30,7 +30,7 @@ struct reduction
3030
}
3131
};
3232

33-
template<class Config, class BinOp, class device_capabilities=void>
33+
template<class Config, class BinOp, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
3434
struct inclusive_scan
3535
{
3636
using scalar_t = typename BinOp::type_t;
@@ -43,7 +43,7 @@ struct inclusive_scan
4343
}
4444
};
4545

46-
template<class Config, class BinOp, class device_capabilities=void>
46+
template<class Config, class BinOp, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
4747
struct exclusive_scan
4848
{
4949
using scalar_t = typename BinOp::type_t;

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ struct items_per_invocation
3636
NBL_CONSTEXPR_STATIC_INLINE uint16_t value0 = BaseItemsPerInvocation;
3737
NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value<VirtualWorkgroup::levels==3, uint16_t,mpl::min_v<uint16_t,ItemsPerInvocationProductLog2,2>, ItemsPerInvocationProductLog2>::value;
3838
NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v<int16_t,ItemsPerInvocationProductLog2-2,0>;
39+
40+
using ItemsPerInvocation = tuple<integral_constant<uint16_t,value0>,integral_constant<uint16_t,value1>,integral_constant<uint16_t,value2> >;
3941
};
4042
}
4143

@@ -53,26 +55,24 @@ struct ArithmeticConfiguration
5355
static_assert(VirtualWorkgroupSize<=WorkgroupSize*SubgroupSize);
5456

5557
using items_per_invoc_t = impl::items_per_invocation<virtual_wg_t, _ItemsPerInvocation>;
56-
using ItemsPerInvocation = tuple<integral_constant<uint16_t,items_per_invoc_t::value0>,integral_constant<uint16_t,items_per_invoc_t::value1>,integral_constant<uint16_t,items_per_invoc_t::value2> >;
5758
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = items_per_invoc_t::value0;
5859
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = items_per_invoc_t::value1;
5960
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
61+
static_assert(ItemsPerInvocation_2<=4, "4 level scan would have been needed with this config!");
6062

61-
// NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
62-
// NBL_CONSTEXPR_STATIC_INLINE uint16_t __ItemsPerVirtualWorkgroup = uint16_t(0x1u) << __ItemsPerVirtualWorkgroupLog2;
6363
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_1 = conditional_value<LevelCount==3,uint16_t,
6464
mpl::max_v<uint16_t, (VirtualWorkgroupSize>>SubgroupSizeLog2), SubgroupSize>,
6565
SubgroupSize*ItemsPerInvocation_1>::value;
6666
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelInputCount_2 = conditional_value<LevelCount==3,uint16_t,SubgroupSize*ItemsPerInvocation_2,0>::value;
6767
NBL_CONSTEXPR_STATIC_INLINE uint16_t __SubgroupsPerVirtualWorkgroup = LevelInputCount_1 / ItemsPerInvocation_1;
6868

69-
// user specified the shared mem size of uint32_ts
69+
// user specified the shared mem size of Scalars
7070
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1,uint16_t,
7171
0,
7272
conditional_value<LevelCount==3,uint16_t,
73-
SubgroupSize*ItemsPerInvocation_2+LevelInputCount_1,
74-
SubgroupSize*ItemsPerInvocation_1
75-
>::value
73+
LevelInputCount_2,
74+
0
75+
>::value + LevelInputCount_1
7676
>::value;
7777

7878
static bool electLast()
@@ -90,30 +90,30 @@ struct ArithmeticConfiguration
9090
// get a coalesced index to store for the next level in shared mem, e.g. level 0 -> level 1
9191
// specify the next level to store values for in template param
9292
// at level==LevelCount-1, it is guaranteed to have SubgroupSize elements
93-
template<uint16_t level>
94-
static uint16_t sharedStoreIndex(const uint16_t subgroupID)
93+
template<uint16_t level NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
94+
static uint16_t sharedStoreIndex(const uint16_t virtualSubgroupID)
9595
{
9696
uint16_t offsetBySubgroup;
9797
if (level == LevelCount-1)
9898
offsetBySubgroup = SubgroupSize;
9999
else
100100
offsetBySubgroup = __SubgroupsPerVirtualWorkgroup;
101101

102-
if (level<2)
103-
return (subgroupID & (ItemsPerInvocation_1-uint16_t(1u))) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_1);
102+
if (level==2)
103+
return LevelInputCount_1 + (virtualSubgroupID & (ItemsPerInvocation_2-uint16_t(1u))) * offsetBySubgroup + (virtualSubgroupID/ItemsPerInvocation_2);
104104
else
105-
return (subgroupID & (ItemsPerInvocation_2-uint16_t(1u))) * offsetBySubgroup + (subgroupID/ItemsPerInvocation_2);
105+
return (virtualSubgroupID & (ItemsPerInvocation_1-uint16_t(1u))) * offsetBySubgroup + (virtualSubgroupID/ItemsPerInvocation_1);
106106
}
107107

108-
template<uint16_t level>
108+
template<uint16_t level NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
109109
static uint16_t sharedStoreIndexFromVirtualIndex(const uint16_t subgroupID, const uint16_t workgroupInVirtualIndex)
110110
{
111111
const uint16_t virtualID = virtualSubgroupID(subgroupID, workgroupInVirtualIndex);
112112
return sharedStoreIndex<level>(virtualID);
113113
}
114114

115115
// get the coalesced index in shared mem at the current level
116-
template<uint16_t level>
116+
template<uint16_t level NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
117117
static uint16_t sharedLoadIndex(const uint16_t invocationIndex, const uint16_t component)
118118
{
119119
if (level == LevelCount-1)

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ struct reduce<Config, BinOp, 3, device_capabilities>
247247

248248
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
249249
// level 1 scan
250-
const uint32_t lv1_smem_size = Config::LevelInputCount_1;
251250
subgroup2::reduction<params_lv1_t> reduction1;
252251
if (glsl::gl_SubgroupID() < Config::LevelInputCount_2)
253252
{
@@ -259,7 +258,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
259258
if (Config::electLast())
260259
{
261260
const uint16_t bankedIndex = Config::template sharedStoreIndex<2>(uint16_t(glsl::gl_SubgroupID()));
262-
scratchAccessor.template set<scalar_t, uint16_t>(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
261+
scratchAccessor.template set<scalar_t, uint16_t>(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
263262
}
264263
}
265264
scratchAccessor.workgroupExecutionAndMemoryBarrier();
@@ -271,7 +270,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
271270
vector_lv2_t lv2_val;
272271
[unroll]
273272
for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)
274-
scratchAccessor.template get<scalar_t, uint16_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
273+
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
275274
lv2_val = reduction2(lv2_val);
276275
if (Config::electLast())
277276
scratchAccessor.template set<scalar_t, uint16_t>(0, lv2_val[Config::ItemsPerInvocation_2-1]);
@@ -305,7 +304,6 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
305304

306305
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
307306
// level 1 scan
308-
const uint32_t lv1_smem_size = Config::LevelInputCount_1;
309307
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
310308
if (glsl::gl_SubgroupID() < Config::LevelInputCount_2)
311309
{
@@ -320,7 +318,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
320318
if (Config::electLast())
321319
{
322320
const uint16_t bankedIndex = Config::template sharedStoreIndex<2>(uint16_t(glsl::gl_SubgroupID()));
323-
scratchAccessor.template set<scalar_t, uint16_t>(lv1_smem_size+bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
321+
scratchAccessor.template set<scalar_t, uint16_t>(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
324322
}
325323
}
326324
scratchAccessor.workgroupExecutionAndMemoryBarrier();
@@ -332,11 +330,11 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
332330
vector_lv2_t lv2_val;
333331
[unroll]
334332
for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)
335-
scratchAccessor.template get<scalar_t, uint16_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
333+
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
336334
lv2_val = inclusiveScan2(lv2_val);
337335
[unroll]
338336
for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)
339-
scratchAccessor.template set<scalar_t, uint16_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
337+
scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
340338
}
341339
scratchAccessor.workgroupExecutionAndMemoryBarrier();
342340

@@ -351,7 +349,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
351349

352350
scalar_t lv2_scan;
353351
const uint16_t bankedIndex = Config::template sharedStoreIndex<2>(uint16_t(glsl::gl_SubgroupID()-1u));
354-
scratchAccessor.template get<scalar_t, uint16_t>(lv1_smem_size+bankedIndex, lv2_scan);
352+
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex, lv2_scan);
355353

356354
[unroll]
357355
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i--)

0 commit comments

Comments
 (0)