Skip to content

Commit 10d9c39

Browse files
committed
subgroup2 implementations
1 parent d4e3738 commit 10d9c39

File tree

4 files changed

+144
-48
lines changed

4 files changed

+144
-48
lines changed

include/nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,6 @@ struct inclusive_scan : impl::inclusive_scan<Binop,device_capabilities_traits<de
2626
// Subgroup exclusive scan front-end: forwards `Binop` together with the
// `shaderSubgroupArithmetic` capability trait of `device_capabilities`
// to `impl::exclusive_scan`, which it inherits from.
template<class Binop, class device_capabilities=void>
struct exclusive_scan
    : impl::exclusive_scan<
        Binop,
        device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic
    >
{
};
2828

29-
}
30-
31-
namespace subgroup2
32-
{
33-
34-
template<typename Config, class BinOp, int32_t ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
35-
struct ArithmeticParams
36-
{
37-
using config_t = Config;
38-
using binop_t = BinOp;
39-
using type_t = typename BinOp::type_t;
40-
41-
NBL_CONSTEXPR_STATIC_INLINE int32_t itemsPerInvocation = ItemsPerInvocation;
42-
NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic /*&& /*some heuristic for when its faster*/;
43-
};
44-
45-
template<typename Params>
46-
struct reduction : impl::reduction<typename Params::binop_t,Params::UseNativeIntrinsics> {};
47-
template<typename Params>
48-
struct inclusive_scan : impl::inclusive_scan<typename Params::binop_t,Params::UseNativeIntrinsics> {};
49-
template<typename Params>
50-
struct exclusive_scan : impl::exclusive_scan<typename Params::binop_t,Params::UseNativeIntrinsics> {};
51-
5229
}
5330
}
5431
}

include/nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -140,31 +140,6 @@ struct reduction<Binop, false>
140140
}
141141

142142
}
143-
144-
namespace subgroup2
145-
{
146-
147-
namespace impl
148-
{
149-
150-
template<class Binop, int32_t ItemsPerInvocation, bool native>
151-
struct reduction
152-
{
153-
using type_t = typename Binop::type_t;
154-
using par_type_t = conditional_t<ItemsPerInvocation < 2, type_t, vector<type_t, ItemsPerInvocation> >;
155-
using op_t = subgroup::impl::reduction<Binop, native>;
156-
157-
type_t operator()(NBL_CONST_REF_ARG(type_t) value)
158-
{
159-
// take the last subgroup invocation's value for the reduction
160-
return BroadcastLast<type_t>(inclusive_scan<Binop,false>::__call(value));
161-
}
162-
};
163-
164-
}
165-
166-
}
167-
168143
}
169144
}
170145

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
5+
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_INCLUDED_
6+
7+
8+
#include "nbl/builtin/hlsl/device_capabilities_traits.hlsl"
9+
10+
#include "nbl/builtin/hlsl/subgroup/basic.hlsl"
11+
#include "nbl/builtin/hlsl/subgroup2/arithmetic_portability_impl.hlsl"
12+
#include "nbl/builtin/hlsl/concepts.hlsl"
13+
14+
15+
namespace nbl
16+
{
17+
namespace hlsl
18+
{
19+
namespace subgroup2
20+
{
21+
22+
// Compile-time parameter bundle consumed by the subgroup2 arithmetic front-ends below.
//
// Config              - configuration type, constrained by `is_configuration_v<Config>`
// BinOp               - binary operator template (instantiated by the implementation for the operand type it needs)
// T                   - scalar element type the operator works on
// ItemsPerInvocation  - number of items handled per invocation (defaults to 1)
// device_capabilities - capability set queried through `device_capabilities_traits`
template<typename Config,template<class> class BinOp, typename T, int32_t ItemsPerInvocation=1, class device_capabilities=void NBL_PRIMARY_REQUIRES(is_configuration_v<Config>)
struct ArithmeticParams
{
    using config_t = Config;
    // `BinOp` is a template template parameter, so a plain `using binop_t = BinOp;` is ill-formed.
    // Exposing it as an alias template keeps it usable as a template argument
    // through `Params::template binop_t`.
    template<class U>
    using binop_t = BinOp<U>;
    using type_t = T;

    NBL_CONSTEXPR_STATIC_INLINE int32_t itemsPerInvocation = ItemsPerInvocation;
    // TODO: additionally gate this on some heuristic for when the native intrinsics are actually faster
    NBL_CONSTEXPR_STATIC_INLINE bool UseNativeIntrinsics = device_capabilities_traits<device_capabilities>::shaderSubgroupArithmetic;
};

// Front-ends: each one unpacks `Params` and inherits the matching `impl::` functor.
// The operator is passed as a template (not a type) because the `impl::` structs
// declare their first parameter as `template<class> class Binop`.
template<typename Params>
struct reduction : impl::reduction<Params::template binop_t,typename Params::type_t,Params::itemsPerInvocation,Params::UseNativeIntrinsics> {};
template<typename Params>
struct inclusive_scan : impl::inclusive_scan<Params::template binop_t,typename Params::type_t,Params::itemsPerInvocation,Params::UseNativeIntrinsics> {};
template<typename Params>
struct exclusive_scan : impl::exclusive_scan<Params::template binop_t,typename Params::type_t,Params::itemsPerInvocation,Params::UseNativeIntrinsics> {};
39+
40+
}
41+
}
42+
}
43+
44+
#endif
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
2+
// This file is part of the "Nabla Engine".
3+
// For conditions of distribution and use, see copyright notice in nabla.h
4+
#ifndef _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
5+
#define _NBL_BUILTIN_HLSL_SUBGROUP2_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
6+
7+
#include "nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl"
8+
9+
namespace nbl
10+
{
11+
namespace hlsl
12+
{
13+
namespace subgroup2
14+
{
15+
16+
namespace impl
17+
{
18+
19+
// Inclusive scan over the subgroup where every invocation contributes `ItemsPerInvocation` values.
//
// Steps performed by operator():
//   1. serial inclusive scan of the invocation's own items,
//   2. exclusive scan of each invocation's last partial result via `subgroup::impl::exclusive_scan`,
//   3. combine that exclusive prefix into every local partial result.
//
// Binop              - binary operator template, instantiated for `type_t`
// T                  - scalar element type
// ItemsPerInvocation - number of items per invocation
// native             - forwarded to `subgroup::impl::exclusive_scan`
template<template<class> class Binop, typename T, int32_t ItemsPerInvocation, bool native>
struct inclusive_scan
{
    // `typename` is only valid in front of a qualified dependent name, so `T` is aliased directly
    using type_t = T;
    using par_type_t = conditional_t<ItemsPerInvocation < 2, type_t, vector<type_t, ItemsPerInvocation> >;
    using binop_t = Binop<type_t>;
    using binop_par_t = Binop<par_type_t>;
    using exclusive_scan_op_t = subgroup::impl::exclusive_scan<binop_t, native>;

    par_type_t operator()(NBL_CONST_REF_ARG(par_type_t) value)
    {
        binop_t binop;
        par_type_t retval;
        // NOTE(review): the subscripts below assume `par_type_t` is indexable even when
        // ItemsPerInvocation < 2 collapses it to the scalar `type_t` - confirm
        retval[0] = value[0];
        [unroll(ItemsPerInvocation-1)]
        for (uint32_t i = 1; i < ItemsPerInvocation; i++)
            retval[i] = binop(retval[i-1], value[i]);

        // scan the per-invocation totals (the last local partial result) across the subgroup
        exclusive_scan_op_t op;
        type_t exclusive = op(retval[ItemsPerInvocation-1]);

        // the prefix from the preceding invocations is the LEFT operand, so that
        // operand order is also correct for non-commutative operators
        [unroll(ItemsPerInvocation)]
        for (uint32_t i = 0; i < ItemsPerInvocation; i++)
            retval[i] = binop(exclusive, retval[i]);
        return retval;
    }
};
46+
47+
// Exclusive scan over the subgroup where every invocation contributes `ItemsPerInvocation` values.
//
// Computed as the inclusive scan shifted by one item: item k (k > 0) takes the inclusive
// result of item k-1 of the same invocation, and item 0 takes the last inclusive result of
// the previous invocation (or `binop_t::identity` on invocation 0).
//
// Binop              - binary operator template
// T                  - scalar element type
// ItemsPerInvocation - number of items per invocation
// native             - forwarded to the inclusive scan
template<template<class> class Binop, typename T, int32_t ItemsPerInvocation, bool native>
struct exclusive_scan
{
    // `typename` is only valid in front of a qualified dependent name, so `T` is aliased directly
    using type_t = T;
    using par_type_t = conditional_t<ItemsPerInvocation < 2, type_t, vector<type_t, ItemsPerInvocation> >;
    using binop_t = Binop<type_t>;
    using binop_par_t = Binop<par_type_t>;
    // must mirror the full parameter list of `subgroup2::impl::inclusive_scan`
    using inclusive_scan_op_t = subgroup2::impl::inclusive_scan<Binop, T, ItemsPerInvocation, native>;

    par_type_t operator()(NBL_CONST_REF_ARG(par_type_t) value)
    {
        // `value` is a const-ref argument, so the scan result lives in its own local
        inclusive_scan_op_t op;
        par_type_t scanned = op(value);

        // fetch the inclusive results of the invocation one below this one
        par_type_t left = glsl::subgroupShuffleUp<par_type_t>(scanned,1);

        // shift the inclusive results up by one slot within the invocation
        par_type_t retval;
        [unroll(ItemsPerInvocation-1)]
        for (uint32_t i = 1; i < ItemsPerInvocation; i++)
            retval[ItemsPerInvocation-i] = scanned[ItemsPerInvocation-i-1];
        // invocation 0 has no predecessor, so its first item starts from the identity
        retval[0] = bool(glsl::gl_SubgroupInvocationID()) ? left[ItemsPerInvocation-1] : binop_t::identity;
        return retval;
    }
};
71+
72+
// Reduction over the subgroup where every invocation contributes `ItemsPerInvocation` values.
//
// The packed value is first reduced with `subgroup::impl::reduction` instantiated for the
// packed operator `Binop<par_type_t>`, then the components of that result are folded
// into a single scalar with `Binop<type_t>`.
//
// Binop              - binary operator template
// T                  - scalar element type
// ItemsPerInvocation - number of items per invocation
// native             - forwarded to `subgroup::impl::reduction`
template<template<class> class Binop, typename T, int32_t ItemsPerInvocation, bool native>
struct reduction
{
    // `typename` is only valid in front of a qualified dependent name, so `T` is aliased directly
    using type_t = T;
    using par_type_t = conditional_t<ItemsPerInvocation < 2, type_t, vector<type_t, ItemsPerInvocation> >;
    using binop_t = Binop<type_t>;
    using binop_par_t = Binop<par_type_t>;
    using op_t = subgroup::impl::reduction<binop_par_t, native>;

    type_t operator()(NBL_CONST_REF_ARG(par_type_t) value)
    {
        binop_t binop;
        op_t op;
        par_type_t result = op(value);
        // seed the fold with the first component instead of reading an uninitialised
        // accumulator; the remaining ItemsPerInvocation-1 components are combined with
        // `binop` only (no extra `+=`), which also matches the unroll count
        type_t retval = result[0];
        [unroll(ItemsPerInvocation-1)]
        for (uint32_t i = 1; i < ItemsPerInvocation; i++)
            retval = binop(retval, result[i]);
        return retval;
    }
};
93+
94+
}
95+
96+
}
97+
}
98+
}
99+
100+
#endif

0 commit comments

Comments
 (0)