1
1
#include "nbl/builtin/hlsl/bda/__ptr.hlsl"
2
2
#include "nbl/builtin/hlsl/sort/counting.hlsl"
3
+ #include "nbl/builtin/hlsl/sort/common.hlsl"
3
4
4
5
[[vk::push_constant]] nbl::hlsl::sort::CountingPushData pushData;
5
6
7
+ struct PtrAccessor
8
+ {
9
+ static PtrAccessor createAccessor (uint64_t addr)
10
+ {
11
+ PtrAccessor ptr;
12
+ ptr.addr = addr;
13
+ return ptr;
14
+ }
15
+
16
+ uint32_t get (uint64_t index)
17
+ {
18
+ return bda::__ptr < uint32_t > (addr + sizeof (uint32_t) * index).template
19
+ deref ().load ();
20
+ }
21
+
22
+ void set (uint64_t index, uint32_t value)
23
+ {
24
+ bda::__ptr < uint32_t > (addr + sizeof (uint32_t) * index).template
25
+ deref ().store (value);
26
+ }
27
+
28
+ uint32_t atomicAdd (uint64_t index, uint32_t value)
29
+ {
30
+ return bda::__ptr < uint32_t > (addr + sizeof (uint32_t) * index).template
31
+ deref ().atomicAdd (value);
32
+ }
33
+
34
+ uint64_t addr;
35
+ };
36
+
37
+ struct DoublePtrAccessor
38
+ {
39
+ static DoublePtrAccessor createAccessor (uint64_t in_addr, uint64_t out_addr)
40
+ {
41
+ DoublePtrAccessor ptr;
42
+ ptr.in_addr = in_addr;
43
+ ptr.out_addr = out_addr;
44
+ return ptr;
45
+ }
46
+
47
+ uint32_t get (uint64_t index)
48
+ {
49
+ return bda::__ptr < uint32_t > (in_addr + sizeof (uint32_t) * index).template
50
+ deref ().load ();
51
+ }
52
+
53
+ void set (uint64_t index, uint32_t value)
54
+ {
55
+ bda::__ptr < uint32_t > (out_addr + sizeof (uint32_t) * index).template
56
+ deref ().store (value);
57
+ }
58
+
59
+ uint64_t in_addr, out_addr;
60
+ };
61
+
6
62
uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize ()
7
63
{
8
64
return uint32_t3 (WorkgroupSize, 1 , 1 );
@@ -11,13 +67,11 @@ uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
11
67
[numthreads (WorkgroupSize, 1 , 1 )]
12
68
void main (uint32_t3 ID : SV_GroupThreadID , uint32_t3 GroupID : SV_GroupID )
13
69
{
14
- nbl::hlsl::sort::counting < bda::PtrAccessor<uint32_t>, bda::PtrAccessor<uint32_t>, bda:: PtrAccessor<uint32_t> > counter;
70
+ nbl::hlsl::sort::counting <DoublePtrAccessor, DoublePtrAccessor, PtrAccessor> counter;
15
71
counter.scatter (
16
- bda::PtrAccessor<uint32_t>::createAccessor (pushData.inputKeyAddress),
17
- bda::PtrAccessor<uint32_t>::createAccessor (pushData.inputValueAddress),
18
- bda::PtrAccessor<uint32_t>::createAccessor (pushData.scratchAddress),
19
- bda::PtrAccessor<uint32_t>::createAccessor (pushData.outputKeyAddress),
20
- bda::PtrAccessor<uint32_t>::createAccessor (pushData.outputValueAddress),
72
+ DoublePtrAccessor::createAccessor (pushData.inputKeyAddress, pushData.outputKeyAddress),
73
+ DoublePtrAccessor::createAccessor (pushData.inputValueAddress, pushData.outputValueAddress),
74
+ PtrAccessor::createAccessor (pushData.scratchAddress),
21
75
pushData
22
76
);
23
77
}
0 commit comments