Skip to content

Commit 4e4f26e

Browse files
committed
added workgroup accessor concepts, refactor accessor usage
1 parent 55d89c5 commit 4e4f26e

File tree

5 files changed

+104
-37
lines changed

5 files changed

+104
-37
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#ifndef _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_WORKGROUP_ARITHMETIC_INCLUDED_
2+
#define _NBL_BUILTIN_HLSL_CONCEPTS_ACCESSORS_WORKGROUP_ARITHMETIC_INCLUDED_
3+
4+
#include "nbl/builtin/hlsl/concepts.hlsl"
5+
6+
namespace nbl
7+
{
8+
namespace hlsl
9+
{
10+
namespace workgroup2
11+
{
12+
13+
#define NBL_CONCEPT_NAME ArithmeticSharedMemoryAccessor
14+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
15+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
16+
#define NBL_CONCEPT_PARAM_0 (accessor, T)
17+
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
18+
#define NBL_CONCEPT_PARAM_2 (val, uint32_t)
19+
NBL_CONCEPT_BEGIN(3)
20+
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
21+
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
22+
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
23+
NBL_CONCEPT_END(
24+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<uint32_t>(index, val)), is_same_v, void))
25+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<uint32_t>(index, val)), is_same_v, void))
26+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
27+
);
28+
#undef val
29+
#undef index
30+
#undef accessor
31+
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
32+
33+
#define NBL_CONCEPT_NAME ArithmeticDataAccessor
34+
#define NBL_CONCEPT_TPLT_PRM_KINDS (typename)
35+
#define NBL_CONCEPT_TPLT_PRM_NAMES (T)
36+
#define NBL_CONCEPT_PARAM_0 (accessor, T)
37+
#define NBL_CONCEPT_PARAM_1 (index, uint32_t)
38+
#define NBL_CONCEPT_PARAM_2 (val, uint32_t)
39+
NBL_CONCEPT_BEGIN(3)
40+
#define accessor NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_0
41+
#define index NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_1
42+
#define val NBL_CONCEPT_PARAM_T NBL_CONCEPT_PARAM_2
43+
NBL_CONCEPT_END(
44+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template set<uint32_t>(index, val)), is_same_v, void))
45+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.template get<uint32_t>(index, val)), is_same_v, void))
46+
((NBL_CONCEPT_REQ_EXPR_RET_TYPE)((accessor.workgroupExecutionAndMemoryBarrier()), is_same_v, void))
47+
);
48+
#undef val
49+
#undef index
50+
#undef accessor
51+
#include <nbl/builtin/hlsl/concepts/__end.hlsl>
52+
53+
}
54+
}
55+
}
56+
57+
#endif

include/nbl/builtin/hlsl/workgroup2/arithmetic.hlsl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "nbl/builtin/hlsl/functional.hlsl"
99
#include "nbl/builtin/hlsl/workgroup/ballot.hlsl"
1010
#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl"
11+
#include "nbl/builtin/hlsl/concepts/accessors/workgroup_arithmetic.hlsl"
1112
#include "nbl/builtin/hlsl/workgroup2/shared_scan.hlsl"
1213

1314

@@ -21,7 +22,7 @@ namespace workgroup2
2122
template<class Config, class BinOp, class device_capabilities=void>
2223
struct reduction
2324
{
24-
template<class DataAccessor, class ScratchAccessor>
25+
template<class DataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<DataAccessor> && ArithmeticSharedMemoryAccessor<ScratchAccessor>)
2526
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
2627
{
2728
impl::reduce<Config,BinOp,Config::LevelCount,device_capabilities> fn;
@@ -32,7 +33,7 @@ struct reduction
3233
template<class Config, class BinOp, class device_capabilities=void>
3334
struct inclusive_scan
3435
{
35-
template<class DataAccessor, class ScratchAccessor>
36+
template<class DataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<DataAccessor> && ArithmeticSharedMemoryAccessor<ScratchAccessor>)
3637
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
3738
{
3839
impl::scan<Config,BinOp,false,Config::LevelCount,device_capabilities> fn;
@@ -43,7 +44,7 @@ struct inclusive_scan
4344
template<class Config, class BinOp, class device_capabilities=void>
4445
struct exclusive_scan
4546
{
46-
template<class DataAccessor, class ScratchAccessor>
47+
template<class DataAccessor, class ScratchAccessor NBL_FUNC_REQUIRES(ArithmeticDataAccessor<DataAccessor> && ArithmeticSharedMemoryAccessor<ScratchAccessor>)
4748
static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
4849
{
4950
impl::scan<Config,BinOp,true,Config::LevelCount,device_capabilities> fn;

0 commit comments

Comments
 (0)