Skip to content

Commit 48a7d16

Browse files
committed
changes to arithmetic accessor concepts
1 parent f07329e commit 48a7d16

File tree

3 files changed

+19
-33
lines changed

3 files changed

+19
-33
lines changed

include/nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_WORKGROUP_ARITHMETIC_INCLUDED_
22
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_WORKGROUP_ARITHMETIC_INCLUDED_
33

4-
#include "nbl/builtin/hlsl/concepts.hlsl"
4+
#include "nbl/builtin/hlsl/concepts/accessors/generic_shared_data.hlsl"
55

66
namespace nbl
77
{
@@ -10,46 +10,30 @@ namespace hlsl
1010
namespace workgroup2
1111
{
1212

13-
#define NBL_CONCEPT_NAME ArithmeticSharedMemoryAccessor
14-
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
15-
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
16-
#define NBL_CONCEPT_PARAM_0 (accessor, T)
17-
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
18-
#define NBL_CONCEPT_PARAM_2 (val, uint32_t)
19-
NBL_CONCEPT_BEGIN(3)
20-
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
21-
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
22-
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
23-
NBL_CONCEPT_END(
24-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<uint32_t>(index, val)), is_same_v, void))
25-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<uint32_t>(index, val)), is_same_v, void))
26-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
27-
);
28-
#undef val
29-
#undef index
30-
#undef accessor
31-
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
13+
template<typename T, typename V, typename I>
14+
NBL_BOOL_CONCEPT ArithmeticSharedMemoryAccessor = concepts::accessors::GenericSharedMemoryAccessor<T,V,I>;
3215

33-
#define NBL_CONCEPT_NAME ArithmeticDataAccessor
34-
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
35-
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
16+
#define NBL_CONCEPT_NAME ArithmeticReadOnlyDataAccessor
17+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)(typename)
18+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)(V)
3619
#define NBL_CONCEPT_PARAM_0 (accessor, T)
3720
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
38-
#define NBL_CONCEPT_PARAM_2 (val, uint32_t)
21+
#define NBL_CONCEPT_PARAM_2 (val, V)
3922
NBL_CONCEPT_BEGIN(3)
4023
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
4124
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
4225
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
4326
NBL_CONCEPT_END(
44-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<uint32_t>(index, val)), is_same_v, void))
45-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<uint32_t>(index, val)), is_same_v, void))
46-
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
27+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<V>(index, val)), is_same_v, void))
4728
);
4829
#undef val
4930
#undef index
5031
#undef accessor
5132
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
5233

34+
template<typename T, typename V, typename I=uint32_t>
35+
NBL_BOOL_CONCEPT ArithmeticDataAccessor = concepts::accessors::GenericDataAccessor<T,V,I>;
36+
5337
}
5438
}
5539
}

include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77

88
#include "nbl/builtin/hlsl/functional.hlsl"
9-
#include "nbl/builtin/hlsl/workgroup/ballot.hlsl"
10-
#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl"
119
#include "nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl"
1210
#include "nbl/builtin/hlsl/workgroup2/shared_scan.hlsl"
1311

@@ -24,7 +22,7 @@ struct reduction
2422
{
2523
using scalar_t = typename BinOp::type_t;
2624

27-
template<class ReadOnlyDataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<ReadOnlyDataAccessor> && ArithmeticSharedMemoryAccessor<ScratchAccessor>)
25+
template<class ReadOnlyDataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticReadOnlyDataAccessor<ReadOnlyDataAccessor,scalar_t> && ArithmeticSharedMemoryAccessor<ScratchAccessor,scalar_t,scalar_t>)
2826
static scalar_t __call(NBL_REF_ARG(ReadOnlyDataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
2927
{
3028
impl::reduce<Config,BinOp,Config::LevelCount,device_capabilities> fn;
@@ -35,7 +33,9 @@ struct reduction
3533
template<class Config, class BinOp, class device_capabilities=void>
3634
struct inclusive_scan
3735
{
38-
template<class DataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<DataAccessor> && ArithmeticSharedMemoryAccessor<ScratchAccessor>)
36+
using scalar_t = typename BinOp::type_t;
37+
38+
template<class DataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<DataAccessor,scalar_t> && ArithmeticSharedMemoryAccessor<ScratchAccessor,scalar_t,scalar_t>)
3939
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
4040
{
4141
impl::scan<Config,BinOp,false,Config::LevelCount,device_capabilities> fn;
@@ -46,7 +46,9 @@ struct inclusive_scan
4646
template<class Config, class BinOp, class device_capabilities=void>
4747
struct exclusive_scan
4848
{
49-
template<class DataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<DataAccessor> && ArithmeticSharedMemoryAccessor<ScratchAccessor>)
49+
using scalar_t = typename BinOp::type_t;
50+
51+
template<class DataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<DataAccessor,scalar_t> && ArithmeticSharedMemoryAccessor<ScratchAccessor,scalar_t,scalar_t>)
5052
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
5153
{
5254
impl::scan<Config,BinOp,true,Config::LevelCount,device_capabilities> fn;

0 commit comments

Comments
 (0)