Skip to content

Commit 542592f

Browse files
committed
soome changes to arithmetic config
1 parent 507904f commit 542592f

File tree

3 files changed

+23
-27
lines changed

3 files changed

+23
-27
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ namespace impl
1818
template<uint16_t WorkgroupSizeLog2, uint16_t SubgroupSizeLog2>
1919
struct virtual_wg_size_log2
2020
{
21+
static_assert(WorkgroupSizeLog2>=SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize");
22+
static_assert(WorkgroupSizeLog2<=SubgroupSizeLog2+4, "WorkgroupSize cannot be larger than SubgroupSize*16");
2123
NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2),uint16_t,conditional_value<(WorkgroupSizeLog2>SubgroupSizeLog2*2+2),uint16_t,3,2>::value,1>::value;
2224
NBL_CONSTEXPR_STATIC_INLINE uint16_t value = mpl::max_v<uint32_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>+SubgroupSizeLog2;
2325
};
@@ -30,6 +32,24 @@ struct items_per_invocation
3032
NBL_CONSTEXPR_STATIC_INLINE uint16_t value1 = uint16_t(0x1u) << conditional_value<VirtualWorkgroup::levels==3, uint16_t,mpl::min_v<uint16_t,ItemsPerInvocationProductLog2,2>, ItemsPerInvocationProductLog2>::value;
3133
NBL_CONSTEXPR_STATIC_INLINE uint16_t value2 = uint16_t(0x1u) << mpl::max_v<int16_t,ItemsPerInvocationProductLog2-2,0>;
3234
};
35+
36+
// explicit specializations for cases that don't fit
37+
#define SPECIALIZE_VIRTUAL_WG_SIZE_CASE(WGLOG2, SGLOG2, LEVELS, VALUE) template<>\
38+
struct virtual_wg_size_log2<WGLOG2, SGLOG2>\
39+
{\
40+
NBL_CONSTEXPR_STATIC_INLINE uint16_t levels = LEVELS;\
41+
NBL_CONSTEXPR_STATIC_INLINE uint16_t value = VALUE;\
42+
};\
43+
44+
SPECIALIZE_VIRTUAL_WG_SIZE_CASE(11,4,3,12);
45+
SPECIALIZE_VIRTUAL_WG_SIZE_CASE(7,7,1,7);
46+
SPECIALIZE_VIRTUAL_WG_SIZE_CASE(6,6,1,6);
47+
SPECIALIZE_VIRTUAL_WG_SIZE_CASE(5,5,1,5);
48+
SPECIALIZE_VIRTUAL_WG_SIZE_CASE(4,4,1,4);
49+
SPECIALIZE_VIRTUAL_WG_SIZE_CASE(3,3,1,3);
50+
SPECIALIZE_VIRTUAL_WG_SIZE_CASE(2,2,1,2);
51+
52+
#undef SPECIALIZE_VIRTUAL_WG_SIZE_CASE
3353
}
3454

3555
template<uint16_t _WorkgroupSizeLog2, uint16_t _SubgroupSizeLog2, uint16_t _ItemsPerInvocation>
@@ -39,7 +59,6 @@ struct ArithmeticConfiguration
3959
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << WorkgroupSizeLog2;
4060
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = _SubgroupSizeLog2;
4161
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;
42-
static_assert(WorkgroupSizeLog2>=_SubgroupSizeLog2, "WorkgroupSize cannot be smaller than SubgroupSize");
4362

4463
// must have at least enough level 0 outputs to feed a single subgroup
4564
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroupLog2 = mpl::max_v<uint16_t, WorkgroupSizeLog2-SubgroupSizeLog2, SubgroupSizeLog2>;
@@ -55,34 +74,11 @@ struct ArithmeticConfiguration
5574
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = items_per_invoc_t::value2;
5675
static_assert(ItemsPerInvocation_1<=4, "3 level scan would have been needed with this config!");
5776

58-
NBL_CONSTEXPR_STATIC_INLINE uint16_t SharedMemSize = conditional_value<LevelCount==3,uint16_t,SubgroupSize*ItemsPerInvocation_2,0>::value + SubgroupsPerVirtualWorkgroup*ItemsPerInvocation_1;
77+
NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementCount = conditional_value<LevelCount==1,uint16_t,0,conditional_value<LevelCount==3,uint16_t,SubgroupSize*ItemsPerInvocation_2,0>::value + SubgroupSize*ItemsPerInvocation_1>::value;
5978
};
6079

61-
// special case when workgroup size 2048 and subgroup size 16 needs 3 levels and virtual workgroup size 4096 to get a full subgroup scan each on level 1 and 2 16x16x16=4096
62-
// specializing with macros because of DXC bug: https://github.com/microsoft/DirectXShaderCom0piler/issues/7007
63-
#define SPECIALIZE_CONFIG_CASE_2048_16(ITEMS_PER_INVOC) template<>\
64-
struct ArithmeticConfiguration<11, 4, ITEMS_PER_INVOC>\
65-
{\
66-
NBL_CONSTEXPR_STATIC_INLINE uint16_t WorkgroupSize = uint16_t(0x1u) << 11u;\
67-
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSizeLog2 = uint16_t(4u);\
68-
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupSize = uint16_t(0x1u) << SubgroupSizeLog2;\
69-
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroupLog2 = 7u;\
70-
NBL_CONSTEXPR_STATIC_INLINE uint16_t SubgroupsPerVirtualWorkgroup = 128u;\
71-
NBL_CONSTEXPR_STATIC_INLINE uint16_t LevelCount = 3u;\
72-
NBL_CONSTEXPR_STATIC_INLINE uint16_t VirtualWorkgroupSize = uint16_t(0x1u) << 4096;\
73-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_0 = ITEMS_PER_INVOC;\
74-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_1 = 1u;\
75-
NBL_CONSTEXPR_STATIC_INLINE uint16_t ItemsPerInvocation_2 = 1u;\
76-
};\
77-
78-
SPECIALIZE_CONFIG_CASE_2048_16(1)
79-
SPECIALIZE_CONFIG_CASE_2048_16(2)
80-
SPECIALIZE_CONFIG_CASE_2048_16(4)
81-
8280
}
8381
}
8482
}
8583

86-
#undef SPECIALIZE_CONFIG_CASE_2048_16
87-
8884
#endif

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "nbl/builtin/hlsl/subgroup2/ballot.hlsl"
1010
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl"
1111
#include "nbl/builtin/hlsl/mpl.hlsl"
12-
#include "nbl/builtin/hlsl/workgroup2/config.hlsl"
12+
#include "nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl"
1313

1414
namespace nbl
1515
{

0 commit comments

Comments
 (0)