Skip to content

Commit 796be61

Browse files
committed
Most changes done; examples 10 and 23 still running
1 parent 52a0a8b commit 796be61

File tree

7 files changed

+61
-72
lines changed

7 files changed

+61
-72
lines changed

include/nbl/builtin/hlsl/bda/bda_accessor.hlsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,18 @@ namespace hlsl
1515
template<typename T>
1616
struct BdaAccessor
1717
{
18+
using type_t = T;
1819
static BdaAccessor<T> create(const bda::__ptr<T> ptr)
1920
{
2021
BdaAccessor<T> accessor;
2122
accessor.ptr = ptr;
2223
return accessor;
2324
}
2425

25-
T get(const uint64_t index)
26+
void get(const uint64_t index, NBL_REF_ARG(T) value)
2627
{
2728
bda::__ptr<T> target = ptr + index;
28-
return target.template deref().load();
29+
value = target.template deref().load();
2930
}
3031

3132
void set(const uint64_t index, const T value)

include/nbl/builtin/hlsl/memory_accessor.hlsl

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,73 +16,43 @@ struct MemoryAdaptor
1616
{
1717
BaseAccessor accessor;
1818

19-
// TODO: template all get,set, atomic... then add static_asserts of `has_method<BaseAccessor,signature>::value`, do vectors and matrices in terms of each other
20-
uint get(const uint ix) { return accessor.get(ix); }
21-
void get(const uint ix, NBL_REF_ARG(uint) value) { value = accessor.get(ix);}
22-
void get(const uint ix, NBL_REF_ARG(uint2) value) { value = uint2(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_));}
23-
void get(const uint ix, NBL_REF_ARG(uint3) value) { value = uint3(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_));}
24-
void get(const uint ix, NBL_REF_ARG(uint4) value) { value = uint4(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_));}
25-
26-
void get(const uint ix, NBL_REF_ARG(int) value) { value = asint(accessor.get(ix));}
27-
void get(const uint ix, NBL_REF_ARG(int2) value) { value = asint(uint2(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_)));}
28-
void get(const uint ix, NBL_REF_ARG(int3) value) { value = asint(uint3(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_)));}
29-
void get(const uint ix, NBL_REF_ARG(int4) value) { value = asint(uint4(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_)));}
19+
// TODO: template atomic... then add static_asserts of `has_method<BaseAccessor,signature>::value`, do vectors and matrices in terms of each other
20+
uint get(const uint ix)
21+
{
22+
uint retVal;
23+
accessor.get(ix, retVal);
24+
return retVal;
25+
}
3026

31-
void get(const uint ix, NBL_REF_ARG(float) value) { value = asfloat(accessor.get(ix));}
32-
void get(const uint ix, NBL_REF_ARG(float2) value) { value = asfloat(uint2(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_)));}
33-
void get(const uint ix, NBL_REF_ARG(float3) value) { value = asfloat(uint3(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_)));}
34-
void get(const uint ix, NBL_REF_ARG(float4) value) { value = asfloat(uint4(accessor.get(ix), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_), accessor.get(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_)));}
27+
template<typename Scalar>
28+
void get(const uint ix, NBL_REF_ARG(Scalar) value) { accessor.get(ix, value);}
29+
template<typename Scalar>
30+
void get(const uint ix, NBL_REF_ARG(vector <Scalar, 2>) value) { accessor.get(ix, value.x), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y);}
31+
template<typename Scalar>
32+
void get(const uint ix, NBL_REF_ARG(vector <Scalar, 3>) value) { accessor.get(ix, value.x), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, value.z);}
33+
template<typename Scalar>
34+
void get(const uint ix, NBL_REF_ARG(vector <Scalar, 4>) value) { accessor.get(ix, value.x), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, value.z), accessor.get(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, value.w);}
3535

36-
void set(const uint ix, const uint value) {accessor.set(ix, value);}
37-
void set(const uint ix, const uint2 value) {
36+
template<typename Scalar>
37+
void set(const uint ix, const Scalar value) {accessor.set(ix, value);}
38+
template<typename Scalar>
39+
void set(const uint ix, const vector <Scalar, 2> value) {
3840
accessor.set(ix, value.x);
3941
accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y);
4042
}
41-
void set(const uint ix, const uint3 value) {
43+
template<typename Scalar>
44+
void set(const uint ix, const vector <Scalar, 3> value) {
4245
accessor.set(ix, value.x);
4346
accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y);
4447
accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, value.z);
4548
}
46-
void set(const uint ix, const uint4 value) {
49+
template<typename Scalar>
50+
void set(const uint ix, const vector <Scalar, 4> value) {
4751
accessor.set(ix, value.x);
4852
accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y);
4953
accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, value.z);
5054
accessor.set(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, value.w);
5155
}
52-
53-
void set(const uint ix, const int value) {accessor.set(ix, asuint(value));}
54-
void set(const uint ix, const int2 value) {
55-
accessor.set(ix, asuint(value.x));
56-
accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
57-
}
58-
void set(const uint ix, const int3 value) {
59-
accessor.set(ix, asuint(value.x));
60-
accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
61-
accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
62-
}
63-
void set(const uint ix, const int4 value) {
64-
accessor.set(ix, asuint(value.x));
65-
accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
66-
accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
67-
accessor.set(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.w));
68-
}
69-
70-
void set(const uint ix, const float value) {accessor.set(ix, asuint(value));}
71-
void set(const uint ix, const float2 value) {
72-
accessor.set(ix, asuint(value.x));
73-
accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
74-
}
75-
void set(const uint ix, const float3 value) {
76-
accessor.set(ix, asuint(value.x));
77-
accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
78-
accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
79-
}
80-
void set(const uint ix, const float4 value) {
81-
accessor.set(ix, asuint(value.x));
82-
accessor.set(ix + _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.y));
83-
accessor.set(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.z));
84-
accessor.set(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, asuint(value.w));
85-
}
8656

8757
void atomicAnd(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
8858
orig = accessor.atomicAnd(ix, value);

include/nbl/builtin/hlsl/sort/counting.hlsl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ template<
2222
typename ValueAccessor,
2323
typename HistogramAccessor,
2424
typename SharedAccessor,
25+
typename key_t = decltype(impl::declval < KeyAccessor > ().get(0)),
2526
bool robust=false
2627
>
2728
struct counting
2829
{
29-
using key_t = decltype(impl::declval < KeyAccessor > ().get(0));
30-
using this_t = counting<GroupSize, KeyBucketCount, KeyAccessor, ValueAccessor, HistogramAccessor, SharedAccessor>;
30+
using this_t = counting<GroupSize, KeyBucketCount, KeyAccessor, ValueAccessor, HistogramAccessor, SharedAccessor, key_t>;
3131

3232
static this_t create(const uint32_t workGroupIndex)
3333
{
@@ -58,7 +58,8 @@ struct counting
5858

5959
for (; index < endIndex; index += GroupSize)
6060
{
61-
uint32_t k = key.get(index);
61+
uint32_t k;
62+
key.get(index, k);
6263
if (robust && (k<params.minimum || k>params.maximum) )
6364
continue;
6465
sdata.atomicAdd(k - params.minimum, (uint32_t) 1);
@@ -79,7 +80,8 @@ struct counting
7980

8081
uint32_t tid = workgroup::SubgroupContiguousIndex();
8182
// because first chunk of histogram and workgroup scan scratch are aliased
82-
uint32_t histogram_value = sdata.get(tid);
83+
uint32_t histogram_value;
84+
sdata.get(tid, histogram_value);
8385

8486
sdata.workgroupExecutionAndMemoryBarrier();
8587

@@ -97,14 +99,18 @@ struct counting
9799
// no if statement about the last iteration needed
98100
if (is_last_wg_invocation)
99101
{
100-
sdata.set(keyBucketStart, sdata.get(keyBucketStart) + sum);
102+
uint32_t beforeSum;
103+
sdata.get(keyBucketStart, beforeSum);
104+
sdata.set(keyBucketStart, beforeSum + sum);
101105
}
102106

103107
// propagate last block tail to next block head and protect against subsequent scans stepping on each other's toes
104108
sdata.workgroupExecutionAndMemoryBarrier();
105109

106110
// no aliasing anymore
107-
sum = inclusive_scan(sdata.get(vid), sdata);
111+
uint32_t atVid;
112+
sdata.get(vid, atVid);
113+
sum = inclusive_scan(atVid, sdata);
108114
if (vid < KeyBucketCount) {
109115
histogram.atomicAdd(vid, sum);
110116
}
@@ -131,7 +137,8 @@ struct counting
131137
{
132138
// have to use modulo operator in case `KeyBucketCount<=2*GroupSize`, better hope KeyBucketCount is Power of Two
133139
const uint32_t shifted_tid = (vtid + shift) % KeyBucketCount;
134-
const uint32_t bucket_value = sdata.get(shifted_tid);
140+
uint32_t bucket_value;
141+
sdata.get(shifted_tid, bucket_value);
135142
const uint32_t firstOutputIndex = histogram.atomicSub(shifted_tid, bucket_value) - bucket_value;
136143

137144
sdata.set(shifted_tid, firstOutputIndex);
@@ -145,10 +152,12 @@ struct counting
145152
[unroll]
146153
for (; index < endIndex; index += GroupSize)
147154
{
148-
const key_t k = key.get(index);
155+
key_t k;
156+
key.get(index, k);
149157
if (robust && (k<params.minimum || k>params.maximum) )
150158
continue;
151-
const uint32_t v = val.get(index);
159+
uint32_t v;
160+
val.get(index, v);
152161
const uint32_t sortedIx = sdata.atomicAdd(k - params.minimum, 1);
153162
key.set(sortedIx, k);
154163
val.set(sortedIx, v);

include/nbl/builtin/hlsl/workgroup/arithmetic.hlsl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ uint16_t ballotCountedBitDWORD(NBL_REF_ARG(BallotAccessor) ballotAccessor)
8484
static const uint16_t DWORDCount = impl::ballot_dword_count<ItemCount>::value;
8585
if (index<DWORDCount)
8686
{
87-
uint32_t bitfield = ballotAccessor.get(index);
87+
uint32_t bitfield;
88+
ballotAccessor.get(index, bitfield);
8889
// strip unwanted bits from bitfield of the last item
8990
const uint16_t Remainder = ItemCount&31;
9091
if (Remainder!=0 && index==DWORDCount-1)
@@ -99,7 +100,8 @@ uint16_t ballotScanBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_
99100
{
100101
const uint16_t subgroupIndex = SubgroupContiguousIndex();
101102
const uint16_t bitfieldIndex = getDWORD(subgroupIndex);
102-
const uint32_t localBitfield = ballotAccessor.get(bitfieldIndex);
103+
uint32_t localBitfield;
104+
ballotAccessor.get(bitfieldIndex, localBitfield);
103105

104106
static const uint16_t DWORDCount = impl::ballot_dword_count<ItemCount>::value;
105107
uint32_t count = exclusive_scan<plus<uint32_t>,DWORDCount,device_capabilities>::template __call<ArithmeticAccessor>(
@@ -110,7 +112,7 @@ uint16_t ballotScanBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_
110112
if (subgroupIndex<DWORDCount)
111113
arithmeticAccessor.set(subgroupIndex,count);
112114
arithmeticAccessor.workgroupExecutionAndMemoryBarrier();
113-
count = arithmeticAccessor.get(bitfieldIndex);
115+
arithmeticAccessor.get(bitfieldIndex, count);
114116
return uint16_t(countbits(localBitfield&(Exclusive ? glsl::gl_SubgroupLtMask():glsl::gl_SubgroupLeMask())[getDWORD(uint16_t(glsl::gl_SubgroupInvocationID()))])+count);
115117
}
116118
}

include/nbl/builtin/hlsl/workgroup/ballot.hlsl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ template<class Accessor>
101101
bool ballotBitExtract(const uint16_t index, NBL_REF_ARG(Accessor) accessor)
102102
{
103103
assert(index<Volume());
104-
return bool(accessor.get(impl::getDWORD(index))&(1u<<(index&31u)));
104+
uint32_t dwordAtIndex;
105+
accessor.get(impl::getDWORD(index), dwordAtIndex);
106+
return bool(dwordAtIndex & (1u<<(index&31u)));
105107
}
106108

107109
/**

include/nbl/builtin/hlsl/workgroup/broadcast.hlsl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ T Broadcast(NBL_CONST_REF_ARG(T) val, NBL_REF_ARG(Accessor) accessor, const uint
2929

3030
accessor.workgroupExecutionAndMemoryBarrier();
3131

32-
return accessor.get(0);
32+
T retVal;
33+
accessor.get(0, retVal);
34+
return retVal;
3335
}
3436

3537
template<typename T, class Accessor>

include/nbl/builtin/hlsl/workgroup/shared_scan.hlsl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ struct reduce
6464
participate = SubgroupContiguousIndex() <= (lastInvocationInLevel >>= glsl::gl_SubgroupSizeLog2());
6565
if(participate)
6666
{
67-
const type_t prevLevelScan = scratchAccessor.get(scanLoadIndex);
67+
type_t prevLevelScan;
68+
scratchAccessor.get(scanLoadIndex, prevLevelScan);
6869
scan = subgroupOp(prevLevelScan);
6970
}
7071
}
@@ -121,7 +122,9 @@ struct scan// : reduce<BinOp,ItemCount> https://github.com/microsoft/DirectXShad
121122
{
122123
// this is fine if on the way up you also += under `if (participate)`
123124
scanStoreIndex -= __base.lastInvocationInLevel+1;
124-
__base.lastLevelScan = binop(__base.lastLevelScan,scratchAccessor.get(scanStoreIndex));
125+
type_t higherLevelEPS;
126+
scratchAccessor.get(scanStoreIndex, higherLevelEPS);
127+
__base.lastLevelScan = binop(__base.lastLevelScan,higherLevelEPS);
125128
}
126129
// now `lastLevelScan` has current level's inclusive prefux sum computed properly
127130
// note we're overwriting the same location with same invocation so no barrier needed
@@ -134,7 +137,7 @@ struct scan// : reduce<BinOp,ItemCount> https://github.com/microsoft/DirectXShad
134137
if (__base.participate)
135138
{
136139
// we either need to prevent OOB read altogether OR cmov identity after the far
137-
__base.lastLevelScan = scratchAccessor.get(scanStoreIndex-storeLoadIndexDiff);
140+
scratchAccessor.get(scanStoreIndex-storeLoadIndexDiff, __base.lastLevelScan);
138141
}
139142
__base.lastInvocationInLevel = lastInvocation>>logShift;
140143
}

0 commit comments

Comments
 (0)