Skip to content

Commit 686fcec

Browse files
committed
Move Accessor code to userspace
1 parent 12a3e80 commit 686fcec

File tree

2 files changed

+94
-9
lines changed

2 files changed

+94
-9
lines changed
Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,39 @@
11
#include "nbl/builtin/hlsl/bda/__ptr.hlsl"
22
#include "nbl/builtin/hlsl/sort/counting.hlsl"
3+
#include "nbl/builtin/hlsl/sort/common.hlsl"
34

45
[[vk::push_constant]] nbl::hlsl::sort::CountingPushData pushData;
56

7+
struct PtrAccessor
8+
{
9+
static PtrAccessor createAccessor(uint64_t addr)
10+
{
11+
PtrAccessor ptr;
12+
ptr.addr = addr;
13+
return ptr;
14+
}
15+
16+
uint32_t get(uint64_t index)
17+
{
18+
return bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
19+
deref().load();
20+
}
21+
22+
void set(uint64_t index, uint32_t value)
23+
{
24+
bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
25+
deref().store(value);
26+
}
27+
28+
uint32_t atomicAdd(uint64_t index, uint32_t value)
29+
{
30+
return bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
31+
deref().atomicAdd(value);
32+
}
33+
34+
uint64_t addr;
35+
};
36+
637
uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
738
{
839
return uint32_t3(WorkgroupSize, 1, 1);
@@ -11,10 +42,10 @@ uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
1142
[numthreads(WorkgroupSize,1,1)]
1243
void main(uint32_t3 ID : SV_GroupThreadID, uint32_t3 GroupID : SV_GroupID)
1344
{
14-
nbl::hlsl::sort::counting < bda::PtrAccessor<uint32_t>, bda::PtrAccessor<uint32_t>, bda::PtrAccessor<uint32_t> > counter;
45+
nbl::hlsl::sort::counting <PtrAccessor, PtrAccessor, PtrAccessor> counter;
1546
counter.histogram(
16-
bda::PtrAccessor<uint32_t>::createAccessor(pushData.inputKeyAddress),
17-
bda::PtrAccessor<uint32_t>::createAccessor(pushData.scratchAddress),
47+
PtrAccessor::createAccessor(pushData.inputKeyAddress),
48+
PtrAccessor::createAccessor(pushData.scratchAddress),
1849
pushData
1950
);
2051
}
Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,64 @@
11
#include "nbl/builtin/hlsl/bda/__ptr.hlsl"
22
#include "nbl/builtin/hlsl/sort/counting.hlsl"
3+
#include "nbl/builtin/hlsl/sort/common.hlsl"
34

45
[[vk::push_constant]] nbl::hlsl::sort::CountingPushData pushData;
56

7+
struct PtrAccessor
8+
{
9+
static PtrAccessor createAccessor(uint64_t addr)
10+
{
11+
PtrAccessor ptr;
12+
ptr.addr = addr;
13+
return ptr;
14+
}
15+
16+
uint32_t get(uint64_t index)
17+
{
18+
return bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
19+
deref().load();
20+
}
21+
22+
void set(uint64_t index, uint32_t value)
23+
{
24+
bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
25+
deref().store(value);
26+
}
27+
28+
uint32_t atomicAdd(uint64_t index, uint32_t value)
29+
{
30+
return bda::__ptr < uint32_t > (addr + sizeof(uint32_t) * index).template
31+
deref().atomicAdd(value);
32+
}
33+
34+
uint64_t addr;
35+
};
36+
37+
struct DoublePtrAccessor
38+
{
39+
static DoublePtrAccessor createAccessor(uint64_t in_addr, uint64_t out_addr)
40+
{
41+
DoublePtrAccessor ptr;
42+
ptr.in_addr = in_addr;
43+
ptr.out_addr = out_addr;
44+
return ptr;
45+
}
46+
47+
uint32_t get(uint64_t index)
48+
{
49+
return bda::__ptr < uint32_t > (in_addr + sizeof(uint32_t) * index).template
50+
deref().load();
51+
}
52+
53+
void set(uint64_t index, uint32_t value)
54+
{
55+
bda::__ptr < uint32_t > (out_addr + sizeof(uint32_t) * index).template
56+
deref().store(value);
57+
}
58+
59+
uint64_t in_addr, out_addr;
60+
};
61+
662
uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
763
{
864
return uint32_t3(WorkgroupSize, 1, 1);
@@ -11,13 +67,11 @@ uint32_t3 nbl::hlsl::glsl::gl_WorkGroupSize()
1167
[numthreads(WorkgroupSize, 1, 1)]
1268
void main(uint32_t3 ID : SV_GroupThreadID, uint32_t3 GroupID : SV_GroupID)
1369
{
14-
nbl::hlsl::sort::counting < bda::PtrAccessor<uint32_t>, bda::PtrAccessor<uint32_t>, bda::PtrAccessor<uint32_t> > counter;
70+
nbl::hlsl::sort::counting <DoublePtrAccessor, DoublePtrAccessor, PtrAccessor> counter;
1571
counter.scatter(
16-
bda::PtrAccessor<uint32_t>::createAccessor(pushData.inputKeyAddress),
17-
bda::PtrAccessor<uint32_t>::createAccessor(pushData.inputValueAddress),
18-
bda::PtrAccessor<uint32_t>::createAccessor(pushData.scratchAddress),
19-
bda::PtrAccessor<uint32_t>::createAccessor(pushData.outputKeyAddress),
20-
bda::PtrAccessor<uint32_t>::createAccessor(pushData.outputValueAddress),
72+
DoublePtrAccessor::createAccessor(pushData.inputKeyAddress, pushData.outputKeyAddress),
73+
DoublePtrAccessor::createAccessor(pushData.inputValueAddress, pushData.outputValueAddress),
74+
PtrAccessor::createAccessor(pushData.scratchAddress),
2175
pushData
2276
);
2377
}

0 commit comments

Comments
 (0)