Skip to content

Commit abfaf67

Browse files
committed
working subgroup2 template and funcs
1 parent 4622f1f commit abfaf67

File tree

3 files changed

+38
-36
lines changed

3 files changed

+38
-36
lines changed

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,27 @@ namespace hlsl
1919
namespace subgroup2
2020
{
2121

22-
template<typename Config, class BinOp, int32_t ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
22+
template<typename Config, class BinOp, int32_t ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(subgroup::is_configuration_v<Config>)
2323
struct ArithmeticParams
2424
{
2525
using config_t = Config;
2626
using binop_t = BinOp;
27-
using type_t = typename BinOp::type_t;
27+
using scalar_t = typename BinOp::type_t; // BinOp should be with scalar type
28+
using type_t = vector<scalar_t, ItemsPerInvocation>;
2829

29-
// static_assert() vector_trait Dimension == ItemsPerInvocation
3030
NBL_CONSTEXPR_STATIC_INLINE int32_t itemsPerInvocation = ItemsPerInvocation;
3131
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /*&& /*some heuristic for when its faster*/;
3232
};
3333

3434
template<typename Params>
35-
struct reduction : impl::reduction<typename Params::binop_t,Params::UseNativeIntrinsics> {};
35+
struct reduction : impl::reduction<typename Params::binop_t,typename Params::type_t,Params::UseNativeIntrinsics> {};
3636
template<typename Params>
37-
struct inclusive_scan : impl::inclusive_scan<typename Params::binop_t,Params::UseNativeIntrinsics> {};
37+
struct inclusive_scan : impl::inclusive_scan<typename Params::binop_t,typename Params::type_t,Params::UseNativeIntrinsics> {};
3838
template<typename Params>
39-
struct exclusive_scan : impl::exclusive_scan<typename Params::binop_t,Params::UseNativeIntrinsics> {};
39+
struct exclusive_scan : impl::exclusive_scan<typename Params::binop_t,typename Params::type_t,Params::UseNativeIntrinsics> {};
4040

4141
}
4242
}
4343
}
4444

45-
#endif
45+
#endif

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,44 +16,46 @@ namespace subgroup2
1616
namespace impl
1717
{
1818

19-
template<class Binop, bool native>
19+
template<class Binop, typename T, bool native>
2020
struct inclusive_scan
2121
{
22-
using type_t = typename Binop::type_t;
23-
using scalar_t = typename Binop::scalar_t;
24-
using binop_t = Binop<scalar_t>;
25-
using binop_par_t = Binop<type_t>;
22+
using type_t = T;
23+
using scalar_t = typename Binop::type_t;
24+
using binop_t = Binop;
2625
using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
2726

27+
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
28+
2829
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
2930
{
3031
binop_t binop;
3132
type_t retval;
3233
retval[0] = value[0];
33-
[unroll(ItemsPerInvocation-1)]
34+
//[unroll(ItemsPerInvocation-1)]
3435
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
3536
retval[i] = binop(retval[i-1], value[i]);
3637

3738
exclusive_scan_op_t op;
3839
scalar_t exclusive = op(retval[ItemsPerInvocation-1]);
3940

40-
[unroll(ItemsPerInvocation)]
41+
//[unroll(ItemsPerInvocation)]
4142
for (uint32_t i = 0; i < ItemsPerInvocation; i++)
4243
retval[i] = binop(retval[i], exclusive);
4344
return retval;
4445
}
4546
};
4647

47-
template<class Binop, bool native>
48+
template<class Binop, typename T, bool native>
4849
struct exclusive_scan
4950
{
50-
using type_t = typename Binop::type_t;
51-
using scalar_t = typename Binop::scalar_t;
52-
using binop_t = Binop<scalar_t>;
53-
using binop_par_t = Binop<type_t>;
54-
using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_par_t, native>;
51+
using type_t = T;
52+
using scalar_t = typename Binop::type_t;
53+
using binop_t = Binop;
54+
using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_t, T, native>;
5555

56-
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
56+
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
57+
58+
type_t operator()(type_t value)
5759
{
5860
inclusive_scan_op_t op;
5961
value = op(value);
@@ -62,32 +64,32 @@ struct exclusive_scan
6264

6365
type_t retval;
6466
retval[0] = bool(glsl::gl_SubgroupInvocationID()) ? left[ItemsPerInvocation-1] : binop_t::identity;
65-
[unroll(ItemsPerInvocation-1)]
67+
//[unroll(ItemsPerInvocation-1)]
6668
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
6769
retval[i] = value[i-1];
6870
return retval;
6971
}
7072
};
7173

72-
template<class Binop, bool native>
74+
template<class Binop, typename T, bool native>
7375
struct reduction
7476
{
75-
using type_t = typename Binop::type_t;
76-
using scalar_t = typename Binop::scalar_t;
77-
using binop_t = Binop<scalar_t>;
78-
using binop_par_t = Binop<type_t>;
79-
using op_t = subgroup::impl::reduction<binop_par_t, native>;
77+
using type_t = T; // TODO? assert scalar_type<T> == scalar_t
78+
using scalar_t = typename Binop::type_t;
79+
using binop_t = Binop;
80+
using op_t = subgroup::impl::reduction<binop_t, native>;
81+
82+
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
8083

8184
scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
8285
{
8386
binop_t binop;
8487
op_t op;
85-
type_t result = op(value);
86-
scalar_t retval;
87-
[unroll(ItemsPerInvocation-1)]
88-
for (uint32_t i = 0; i < ItemsPerInvocation; i++)
89-
retval += binop(retval, result[i]);
90-
return retval;
88+
scalar_t retval = value[0];
89+
//[unroll(ItemsPerInvocation-1)]
90+
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
91+
retval += binop(retval, value[i]);
92+
return op(retval);
9193
}
9294
};
9395

@@ -97,4 +99,4 @@ struct reduction
9799
}
98100
}
99101

100-
#endif
102+
#endif

0 commit comments

Comments
 (0)