Skip to content

Commit 4622f1f

Browse files
committed
changed template parameters
1 parent f2a281c commit 4622f1f

File tree

2 files changed

+30
-29
lines changed

2 files changed

+30
-29
lines changed

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,24 @@ namespace hlsl
1919
namespace subgroup2
2020
{
2121

// ArithmeticParams: compile-time parameter bundle consumed by the public
// reduction / inclusive_scan / exclusive_scan wrappers declared after it.
// It exposes config_t, binop_t, type_t, itemsPerInvocation and UseNativeIntrinsics.
//
// What this hunk changes: BinOp used to be a template template parameter
// (`template<class> class BinOp`) paired with an explicit element type `T`.
// It is now an already-instantiated class, and the element type is read from
// the nested `BinOp::type_t` instead of being passed separately.
// NOTE(review): callers that spelled the old `<Config, BinOp, T, ...>` argument
// list must be updated; no such call sites are visible in this diff.
22-
template<typename Config,template<class> class BinOp, typename T, int32_t ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
22+
template<typename Config, class BinOp, int32_t ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
2323
struct ArithmeticParams
2424
{
2525
using config_t = Config;
2626
using binop_t = BinOp;
27-
using type_t = T;
// type_t now comes from the operator type itself, so BinOp must provide a nested `type_t`.
27+
using type_t = typename BinOp::type_t;
2828

29+
// static_assert() vector_trait Dimension == ItemsPerInvocation
// NOTE(review): the line above is only a reminder; no static_assert is present in
// this hunk, so nothing here checks that type_t's dimension matches ItemsPerInvocation.
2930
NBL_CONSTEXPR_STATIC_INLINE int32_t itemsPerInvocation = ItemsPerInvocation;
// The trailing block comment below contains a second `/*`; block comments do not nest,
// so it ends at the first `*/` and the value is just the shaderSubgroupArithmetic trait.
// NOTE(review): what device_capabilities_traits yields for the default `void` argument
// is not shown here - confirm.
3031
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /*&& /*some heuristic for when its faster*/;
3132
};
3233

// Public reduction entry point: an empty struct that inherits from impl::reduction,
// selecting the implementation from Params::binop_t and Params::UseNativeIntrinsics.
// The removed line also forwarded Params::type_t and Params::itemsPerInvocation;
// the added line no longer passes either of them to the impl.
3334
template<typename Params>
34-
struct reduction : impl::reduction<typename Params::binop_t,typename Params::type_t,Params::itemsPerInvocation,Params::UseNativeIntrinsics> {};
35+
struct reduction : impl::reduction<typename Params::binop_t,Params::UseNativeIntrinsics> {};
// Public inclusive-scan entry point: an empty struct that inherits from impl::inclusive_scan,
// selecting the implementation from Params::binop_t and Params::UseNativeIntrinsics.
// The removed line also forwarded Params::type_t and Params::itemsPerInvocation;
// the added line no longer passes either of them to the impl.
3536
template<typename Params>
36-
struct inclusive_scan : impl::inclusive_scan<typename Params::binop_t,typename Params::type_t,Params::itemsPerInvocation,Params::UseNativeIntrinsics> {};
37+
struct inclusive_scan : impl::inclusive_scan<typename Params::binop_t,Params::UseNativeIntrinsics> {};
// Public exclusive-scan entry point: an empty struct that inherits from impl::exclusive_scan,
// selecting the implementation from Params::binop_t and Params::UseNativeIntrinsics.
// The removed line also forwarded Params::type_t and Params::itemsPerInvocation;
// the added line no longer passes either of them to the impl.
3738
template<typename Params>
38-
struct exclusive_scan : impl::exclusive_scan<typename Params::binop_t,typename Params::type_t,Params::itemsPerInvocation,Params::UseNativeIntrinsics> {};
39+
struct exclusive_scan : impl::exclusive_scan<typename Params::binop_t,Params::UseNativeIntrinsics> {};
3940

4041
}
4142
}

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,26 @@ namespace subgroup2
1616
namespace impl
1717
{
1818

// impl::inclusive_scan (diff hunk; part of operator() is elided by the hunk boundary below).
// Visible flow of operator(): combine this invocation's items left-to-right with the
// scalar binop, run the subgroup-level exclusive scan on the last combined item, then
// loop over all items again (loop body not shown in this diff).
//
// NOTE(review): after this change the template takes only (Binop, native), yet the
// visible body still uses `ItemsPerInvocation`, which is no longer a template parameter
// and is not declared anywhere in the lines shown - confirm where it is meant to come from.
19-
template<template<class> class Binop, typename T, int32_t ItemsPerInvocation, bool native>
19+
template<class Binop, bool native>
2020
struct inclusive_scan
2121
{
22-
using type_t = typename T;
23-
using par_type_t = conditional_t<ItemsPerInvocation < 2, type_t, vector<type_t, ItemsPerInvocation> >;
24-
using binop_t = Binop<type_t>;
25-
using binop_par_t = Binop<par_type_t>;
// type_t is now the per-invocation value type and scalar_t the single-item type,
// both read from nested typedefs that Binop is required to provide.
22+
using type_t = typename Binop::type_t;
23+
using scalar_t = typename Binop::scalar_t;
// NOTE(review): the two aliases below apply template arguments to `Binop`, but the added
// template header declares Binop as a plain `class`, not `template<class> class`.
// This looks ill-formed as written - confirm against the compiler.
24+
using binop_t = Binop<scalar_t>;
// binop_par_t is not referenced in the visible lines of this struct.
25+
using binop_par_t = Binop<type_t>;
// Note the namespace: this refers to subgroup::impl (not subgroup2::impl) and is
// instantiated with the scalar operator type.
2626
using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
2727

28-
par_type_t operator()(NBL_CONST_REF_ARG(par_type_t) value)
28+
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
2929
{
3030
binop_t binop;
31-
par_type_t retval;
31+
type_t retval;
// Running combine over this invocation's own items: retval[i] = binop(retval[i-1], value[i]).
3232
retval[0] = value[0];
3333
[unroll(ItemsPerInvocation-1)]
3434
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
3535
retval[i] = binop(retval[i-1], value[i]);
3636

// Feed the last combined item (the total over this invocation's items) to the
// subgroup-level exclusive scan; the result is held as a scalar.
3737
exclusive_scan_op_t op;
38-
type_t exclusive = op(retval[ItemsPerInvocation-1]);
38+
scalar_t exclusive = op(retval[ItemsPerInvocation-1]);
3939

// Second pass over every item. The loop body and the return statement fall inside the
// part of the file this diff does not show.
4040
[unroll(ItemsPerInvocation)]
4141
for (uint32_t i = 0; i < ItemsPerInvocation; i++)
@@ -44,23 +44,23 @@ struct inclusive_scan
4444
}
4545
};
4646

// impl::exclusive_scan (diff hunk; part of operator() is elided by the hunk boundary below).
// Visible flow of operator(): run the subgroup2 inclusive scan on the input, shuffle the
// result up by one, seed item 0 from the shuffled value or the operator identity, then
// loop over the remaining items (loop body not shown in this diff).
//
// NOTE(review): after this change the template takes only (Binop, native), yet the
// visible body still uses `ItemsPerInvocation`, which is no longer a template parameter
// and is not declared anywhere in the lines shown - confirm where it is meant to come from.
47-
template<template<class> class Binop, typename T, int32_t ItemsPerInvocation, bool native>
47+
template<class Binop, bool native>
4848
struct exclusive_scan
4949
{
50-
using type_t = typename T;
51-
using par_type_t = conditional_t<ItemsPerInvocation < 2, type_t, vector<type_t, ItemsPerInvocation> >;
52-
using binop_t = Binop<type_t>;
53-
using binop_par_t = Binop<par_type_t>;
// type_t is now the per-invocation value type and scalar_t the single-item type,
// both read from nested typedefs that Binop is required to provide.
50+
using type_t = typename Binop::type_t;
51+
using scalar_t = typename Binop::scalar_t;
// NOTE(review): the two aliases below apply template arguments to `Binop`, but the added
// template header declares Binop as a plain `class`, not `template<class> class`.
// This looks ill-formed as written - confirm against the compiler.
52+
using binop_t = Binop<scalar_t>;
53+
using binop_par_t = Binop<type_t>;
// Uses the subgroup2 inclusive scan declared earlier in this file, instantiated with binop_par_t.
5454
using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_par_t, native>;
5555

56-
par_type_t operator()(NBL_CONST_REF_ARG(par_type_t) value)
56+
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
5757
{
5858
inclusive_scan_op_t op;
// NOTE(review): this assigns to a parameter declared through NBL_CONST_REF_ARG. If that
// macro yields a const-qualified argument (its name suggests so; the definition is not
// shown), the assignment would be rejected - confirm, or use a local copy.
5959
value = op(value);
6060

// Shuffle the inclusive result up by 1. Presumably this delivers the previous
// invocation's value - confirm against glsl::subgroupShuffleUp.
61-
par_type_t left = glsl::subgroupShuffleUp<par_type_t>(value,1);
61+
type_t left = glsl::subgroupShuffleUp<type_t>(value,1);
6262

63-
par_type_t retval;
63+
type_t retval;
// Item 0: when gl_SubgroupInvocationID() is non-zero take the last item of `left`,
// otherwise start from binop_t::identity.
6464
retval[0] = bool(glsl::gl_SubgroupInvocationID()) ? left[ItemsPerInvocation-1] : binop_t::identity;
// Remaining items. The loop body and the return statement fall inside the part of the
// file this diff does not show.
6565
[unroll(ItemsPerInvocation-1)]
6666
for (uint32_t i = 1; i < ItemsPerInvocation; i++)
@@ -69,21 +69,21 @@ struct exclusive_scan
6969
}
7070
};
7171

72-
template<template<class> class Binop, typename T, int32_t ItemsPerInvocation, bool native>
72+
template<class Binop, bool native>
7373
struct reduction
7474
{
75-
using type_t = typename T;
76-
using par_type_t = conditional_t<ItemsPerInvocation < 2, type_t, vector<type_t, ItemsPerInvocation> >;
77-
using binop_t = Binop<type_t>;
78-
using binop_par_t = Binop<par_type_t>;
75+
using type_t = typename Binop::type_t;
76+
using scalar_t = typename Binop::scalar_t;
77+
using binop_t = Binop<scalar_t>;
78+
using binop_par_t = Binop<type_t>;
7979
using op_t = subgroup::impl::reduction<binop_par_t, native>;
8080

81-
type_t operator()(NBL_CONST_REF_ARG(par_type_t) value)
81+
scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
8282
{
8383
binop_t binop;
8484
op_t op;
85-
par_type_t result = op(value);
86-
type_t retval;
85+
type_t result = op(value);
86+
scalar_t retval;
8787
[unroll(ItemsPerInvocation-1)]
8888
for (uint32_t i = 0; i < ItemsPerInvocation; i++)
8989
retval += binop(retval, result[i]);

0 commit comments

Comments
 (0)