Skip to content

Commit e88f51a

Browse files
committed
Partial specialization for ItemsPerInvocation == 1
1 parent 1478837 commit e88f51a

File tree

3 files changed

+55
-13
lines changed

3 files changed

+55
-13
lines changed

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability.hlsl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,18 @@ struct ArithmeticParams
2525
using config_t = Config;
2626
using binop_t = BinOp;
2727
using scalar_t = typename BinOp::type_t; // BinOp should be with scalar type
28-
using type_t = vector<scalar_t, _ItemsPerInvocation>;
28+
using type_t = conditional_t<_ItemsPerInvocation<2, scalar_t, vector<scalar_t, _ItemsPerInvocation> >;
2929

3030
NBL_CONSTEXPR_STATIC_INLINE int32_t ItemsPerInvocation = _ItemsPerInvocation;
3131
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /*&& /*some heuristic for when its faster*/;
3232
};
3333

3434
template<typename Params>
35-
struct reduction : impl::reduction<typename Params::binop_t,typename Params::type_t,Params::UseNativeIntrinsics> {};
35+
struct reduction : impl::reduction<typename Params::binop_t,typename Params::type_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
3636
template<typename Params>
37-
struct inclusive_scan : impl::inclusive_scan<typename Params::binop_t,typename Params::type_t,Params::UseNativeIntrinsics> {};
37+
struct inclusive_scan : impl::inclusive_scan<typename Params::binop_t,typename Params::type_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
3838
template<typename Params>
39-
struct exclusive_scan : impl::exclusive_scan<typename Params::binop_t,typename Params::type_t,Params::UseNativeIntrinsics> {};
39+
struct exclusive_scan : impl::exclusive_scan<typename Params::binop_t,typename Params::type_t,Params::ItemsPerInvocation,Params::UseNativeIntrinsics> {};
4040

4141
}
4242
}

include/nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
55
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
66

7-
#include "nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl"
7+
#include "nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl"
88

99
namespace nbl
1010
{
@@ -16,15 +16,15 @@ namespace subgroup2
1616
namespace impl
1717
{
1818

19-
template<class Binop, typename T, bool native>
19+
template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
2020
struct inclusive_scan
2121
{
2222
using type_t = T;
2323
using scalar_t = typename Binop::type_t;
2424
using binop_t = Binop;
2525
using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;
2626

27-
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
27+
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
2828

2929
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
3030
{
@@ -45,15 +45,15 @@ struct inclusive_scan
4545
}
4646
};
4747

48-
template<class Binop, typename T, bool native>
48+
template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
4949
struct exclusive_scan
5050
{
5151
using type_t = T;
5252
using scalar_t = typename Binop::type_t;
5353
using binop_t = Binop;
54-
using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_t, T, native>;
54+
using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<binop_t, T, ItemsPerInvocation, native>;
5555

56-
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
56+
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
5757

5858
type_t operator()(type_t value)
5959
{
@@ -71,15 +71,15 @@ struct exclusive_scan
7171
}
7272
};
7373

74-
template<class Binop, typename T, bool native>
74+
template<class Binop, typename T, uint32_t ItemsPerInvocation, bool native>
7575
struct reduction
7676
{
7777
using type_t = T; // TODO? assert scalar_type<T> == scalar_t
7878
using scalar_t = typename Binop::type_t;
7979
using binop_t = Binop;
8080
using op_t = subgroup::impl::reduction<binop_t, native>;
8181

82-
NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
82+
// NBL_CONSTEXPR_STATIC_INLINE uint32_t ItemsPerInvocation = vector_traits<T>::Dimension;
8383

8484
scalar_t operator()(NBL_CONST_REF_ARG(type_t) value)
8585
{
@@ -93,6 +93,48 @@ struct reduction
9393
}
9494
};
9595

96+
97+
// Partial specializations for ItemsPerInvocation == 1: these forward directly to the subgroup:: implementations
98+
// Partial specialization of `inclusive_scan` selected when ItemsPerInvocation is 1.
//
// The functor instantiates `subgroup::impl::inclusive_scan<Binop, native>` and
// returns whatever that op produces for `value`; no extra work is done here.
//
// Template parameters:
//   Binop  - binary operation type; must expose a nested `type_t`.
//   T      - type of the value passed in and returned.
//   native - forwarded as-is to the `subgroup::impl` op.
//
// NOTE(review): the original author's to-do says T should be a scalar type and
// equal to `Binop::type_t`; nothing in this struct enforces that yet - confirm
// at the call sites or add an assert.
template<class Binop, typename T, bool native>
struct inclusive_scan<Binop, T, 1, native>
{
    // Expose the same member aliases as the primary template so code that
    // queries `type_t` / `scalar_t` keeps compiling for this specialization.
    using type_t = T;
    using scalar_t = typename Binop::type_t;
    using binop_t = Binop;
    using op_t = subgroup::impl::inclusive_scan<binop_t, native>;

    // Returns the result of the forwarded subgroup inclusive scan for `value`.
    T operator()(NBL_CONST_REF_ARG(T) value)
    {
        op_t op;
        return op(value);
    }
};
111+
112+
// Partial specialization of `exclusive_scan` selected when ItemsPerInvocation is 1.
//
// The functor instantiates `subgroup::impl::exclusive_scan<Binop, native>` and
// returns whatever that op produces for `value`; no extra work is done here.
//
// Template parameters:
//   Binop  - binary operation type; must expose a nested `type_t`.
//   T      - type of the value passed in and returned.
//   native - forwarded as-is to the `subgroup::impl` op.
//
// NOTE(review): T is presumably a scalar equal to `Binop::type_t`; this struct
// does not check it - confirm at the call sites.
template<class Binop, typename T, bool native>
struct exclusive_scan<Binop, T, 1, native>
{
    // Expose the same member aliases as the primary template so code that
    // queries `type_t` / `scalar_t` keeps compiling for this specialization.
    using type_t = T;
    using scalar_t = typename Binop::type_t;
    using binop_t = Binop;
    using op_t = subgroup::impl::exclusive_scan<binop_t, native>;

    // Returns the result of the forwarded subgroup exclusive scan for `value`.
    T operator()(NBL_CONST_REF_ARG(T) value)
    {
        op_t op;
        return op(value);
    }
};
124+
125+
// Partial specialization of `reduction` selected when ItemsPerInvocation is 1.
//
// The functor instantiates `subgroup::impl::reduction<Binop, native>` and
// returns whatever that op produces for `value`; no extra work is done here.
//
// Template parameters:
//   Binop  - binary operation type; must expose a nested `type_t`.
//   T      - type of the value passed in and returned.
//   native - forwarded as-is to the `subgroup::impl` op.
//
// NOTE(review): the primary template returns `scalar_t` while this
// specialization returns `T`; the two agree only if T == Binop::type_t,
// which is not checked here - confirm at the call sites.
template<class Binop, typename T, bool native>
struct reduction<Binop, T, 1, native>
{
    // Expose the same member aliases as the primary template so code that
    // queries `type_t` / `scalar_t` keeps compiling for this specialization.
    using type_t = T;
    using scalar_t = typename Binop::type_t;
    using binop_t = Binop;
    using op_t = subgroup::impl::reduction<binop_t, native>;

    // Returns the result of the forwarded subgroup reduction for `value`.
    T operator()(NBL_CONST_REF_ARG(T) value)
    {
        op_t op;
        return op(value);
    }
};
137+
96138
}
97139

98140
}

0 commit comments

Comments
 (0)