Commit 3eb1826

Adding accessor changes
1 parent 6941998 commit 3eb1826

File tree: 3 files changed, +47 -80 lines changed

include/nbl/builtin/hlsl/memory_accessor.hlsl

Lines changed: 32 additions & 66 deletions
@@ -27,11 +27,11 @@ struct MemoryAdaptor
 	template<typename Scalar>
 	void get(const uint ix, NBL_REF_ARG(Scalar) value) { accessor.get(ix, value);}
 	template<typename Scalar>
-	void get(const uint ix, NBL_REF_ARG(vector <Scalar, 2>) value) { accessor.get(ix, value.x), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y);}
+	void get(const uint ix, NBL_REF_ARG(vector <Scalar, 2>) value) { accessor.get(ix, value.x), accessor.get(ix + Stride, value.y);}
 	template<typename Scalar>
-	void get(const uint ix, NBL_REF_ARG(vector <Scalar, 3>) value) { accessor.get(ix, value.x), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, value.z);}
+	void get(const uint ix, NBL_REF_ARG(vector <Scalar, 3>) value) { accessor.get(ix, value.x), accessor.get(ix + Stride, value.y), accessor.get(ix + 2 * Stride, value.z);}
 	template<typename Scalar>
-	void get(const uint ix, NBL_REF_ARG(vector <Scalar, 4>) value) { accessor.get(ix, value.x), accessor.get(ix + _NBL_HLSL_WORKGROUP_SIZE_, value.y), accessor.get(ix + 2 * _NBL_HLSL_WORKGROUP_SIZE_, value.z), accessor.get(ix + 3 * _NBL_HLSL_WORKGROUP_SIZE_, value.w);}
+	void get(const uint ix, NBL_REF_ARG(vector <Scalar, 4>) value) { accessor.get(ix, value.x), accessor.get(ix + Stride, value.y), accessor.get(ix + 2 * Stride, value.z), accessor.get(ix + 3 * Stride, value.w);}
 
 	template<typename Scalar>
 	void set(const uint ix, const Scalar value) {accessor.set(ix, value);}
@@ -41,13 +41,13 @@ struct MemoryAdaptor
 		accessor.set(ix + Stride, value.y);
 	}
 	template<typename Scalar>
-	void set(const uint ix, const <Scalar, 3> value) {
+	void set(const uint ix, const vector <Scalar, 3> value) {
 		accessor.set(ix, value.x);
 		accessor.set(ix + Stride, value.y);
 		accessor.set(ix + 2 * Stride, value.z);
 	}
 	template<typename Scalar>
-	void set(const uint ix, const <Scalar, 4> value) {
+	void set(const uint ix, const vector <Scalar, 4> value) {
 		accessor.set(ix, value.x);
 		accessor.set(ix + Stride, value.y);
 		accessor.set(ix + 2 * Stride, value.z);
@@ -109,75 +109,37 @@ struct MemoryAdaptor<BaseAccessor, 0>
 {
 	BaseAccessor accessor;
 	uint32_t stride;
-
-	// TODO: template all get,set, atomic... then add static_asserts of `has_method<BaseAccessor,signature>::value`, do vectors and matrices in terms of each other
-	uint get(const uint ix) { return accessor.get(ix); }
-	void get(const uint ix, NBL_REF_ARG(uint) value) { value = accessor.get(ix);}
-	void get(const uint ix, NBL_REF_ARG(uint2) value) { value = uint2(accessor.get(ix), accessor.get(ix + stride));}
-	void get(const uint ix, NBL_REF_ARG(uint3) value) { value = uint3(accessor.get(ix), accessor.get(ix + stride), accessor.get(ix + 2 * stride));}
-	void get(const uint ix, NBL_REF_ARG(uint4) value) { value = uint4(accessor.get(ix), accessor.get(ix + stride), accessor.get(ix + 2 * stride), accessor.get(ix + 3 * stride));}
-
-	void get(const uint ix, NBL_REF_ARG(int) value) { value = asint(accessor.get(ix));}
-	void get(const uint ix, NBL_REF_ARG(int2) value) { value = asint(uint2(accessor.get(ix), accessor.get(ix + stride)));}
-	void get(const uint ix, NBL_REF_ARG(int3) value) { value = asint(uint3(accessor.get(ix), accessor.get(ix + stride), accessor.get(ix + 2 * stride)));}
-	void get(const uint ix, NBL_REF_ARG(int4) value) { value = asint(uint4(accessor.get(ix), accessor.get(ix + stride), accessor.get(ix + 2 * stride), accessor.get(ix + 3 * stride)));}
-
-	void get(const uint ix, NBL_REF_ARG(float) value) { value = asfloat(accessor.get(ix));}
-	void get(const uint ix, NBL_REF_ARG(float2) value) { value = asfloat(uint2(accessor.get(ix), accessor.get(ix + stride)));}
-	void get(const uint ix, NBL_REF_ARG(float3) value) { value = asfloat(uint3(accessor.get(ix), accessor.get(ix + stride), accessor.get(ix + 2 * stride)));}
-	void get(const uint ix, NBL_REF_ARG(float4) value) { value = asfloat(uint4(accessor.get(ix), accessor.get(ix + stride), accessor.get(ix + 2 * stride), accessor.get(ix + 3 * stride)));}
+
+	template<typename Scalar>
+	void get(const uint ix, NBL_REF_ARG(Scalar) value) { accessor.get(ix, value);}
+	template<typename Scalar>
+	void get(const uint ix, NBL_REF_ARG(vector <Scalar, 2>) value) { accessor.get(ix, value.x), accessor.get(ix + stride, value.y);}
+	template<typename Scalar>
+	void get(const uint ix, NBL_REF_ARG(vector <Scalar, 3>) value) { accessor.get(ix, value.x), accessor.get(ix + stride, value.y), accessor.get(ix + 2 * stride, value.z);}
+	template<typename Scalar>
+	void get(const uint ix, NBL_REF_ARG(vector <Scalar, 4>) value) { accessor.get(ix, value.x), accessor.get(ix + stride, value.y), accessor.get(ix + 2 * stride, value.z), accessor.get(ix + 3 * stride, value.w);}
 
-	void set(const uint ix, const uint value) {accessor.set(ix, value);}
-	void set(const uint ix, const uint2 value) {
+	template<typename Scalar>
+	void set(const uint ix, const Scalar value) {accessor.set(ix, value);}
+	template<typename Scalar>
+	void set(const uint ix, const vector <Scalar, 2> value) {
 		accessor.set(ix, value.x);
 		accessor.set(ix + stride, value.y);
 	}
-	void set(const uint ix, const uint3 value) {
+	template<typename Scalar>
+	void set(const uint ix, const vector <Scalar, 3> value) {
 		accessor.set(ix, value.x);
 		accessor.set(ix + stride, value.y);
 		accessor.set(ix + 2 * stride, value.z);
 	}
-	void set(const uint ix, const uint4 value) {
+	template<typename Scalar>
+	void set(const uint ix, const vector <Scalar, 4> value) {
 		accessor.set(ix, value.x);
 		accessor.set(ix + stride, value.y);
 		accessor.set(ix + 2 * stride, value.z);
 		accessor.set(ix + 3 * stride, value.w);
 	}
 
-	void set(const uint ix, const int value) {accessor.set(ix, asuint(value));}
-	void set(const uint ix, const int2 value) {
-		accessor.set(ix, asuint(value.x));
-		accessor.set(ix + stride, asuint(value.y));
-	}
-	void set(const uint ix, const int3 value) {
-		accessor.set(ix, asuint(value.x));
-		accessor.set(ix + stride, asuint(value.y));
-		accessor.set(ix + 2 * stride, asuint(value.z));
-	}
-	void set(const uint ix, const int4 value) {
-		accessor.set(ix, asuint(value.x));
-		accessor.set(ix + stride, asuint(value.y));
-		accessor.set(ix + 2 * stride, asuint(value.z));
-		accessor.set(ix + 3 * stride, asuint(value.w));
-	}
-
-	void set(const uint ix, const float value) {accessor.set(ix, asuint(value));}
-	void set(const uint ix, const float2 value) {
-		accessor.set(ix, asuint(value.x));
-		accessor.set(ix + stride, asuint(value.y));
-	}
-	void set(const uint ix, const float3 value) {
-		accessor.set(ix, asuint(value.x));
-		accessor.set(ix + stride, asuint(value.y));
-		accessor.set(ix + 2 * stride, asuint(value.z));
-	}
-	void set(const uint ix, const float4 value) {
-		accessor.set(ix, asuint(value.x));
-		accessor.set(ix + stride, asuint(value.y));
-		accessor.set(ix + 2 * stride, asuint(value.z));
-		accessor.set(ix + 3 * stride, asuint(value.w));
-	}
-
 	void atomicAnd(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
 		orig = accessor.atomicAnd(ix, value);
 	}
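For context, the adaptor now only requires that the wrapped BaseAccessor expose templated scalar get(index, value) / set(index, value) methods for it to forward to. Below is a minimal sketch of such a base accessor over workgroup shared memory; the array name and size, and the use of nbl::hlsl::bit_cast and glsl::barrier(), are assumptions for illustration and are not part of this commit.

// Hypothetical shared-memory proxy satisfying the interface MemoryAdaptor forwards to; NOT part of this commit.
groupshared uint32_t scratch[512]; // size chosen arbitrarily for the sketch

struct ScratchProxy
{
	template<typename T>
	void get(const uint32_t ix, NBL_REF_ARG(T) value) { value = nbl::hlsl::bit_cast<T, uint32_t>(scratch[ix]); }

	template<typename T>
	void set(const uint32_t ix, const T value) { scratch[ix] = nbl::hlsl::bit_cast<uint32_t, T>(value); }

	void workgroupExecutionAndMemoryBarrier() { glsl::barrier(); }
};

// MemoryAdaptor<ScratchProxy> (compile-time Stride) or MemoryAdaptor<ScratchProxy, 0> (runtime stride)
// can then wrap an instance of this proxy and provide the vector get/set overloads shown above.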
@@ -229,29 +191,33 @@ struct MemoryAdaptor<BaseAccessor, 0>
 
 // ---------------------------------------------- Offset Accessor ----------------------------------------------------
 
-template<class BaseAccessor, class AccessorType, uint32_t Offset>
+template<class BaseAccessor, uint32_t Offset>
 struct OffsetAccessor
 {
 	BaseAccessor accessor;
 
-	void set(uint32_t idx, NBL_REF_ARG(AccessorType) x) {accessor.set(idx + Offset, x);}
+	template <typename T>
+	void set(uint32_t idx, T value) {accessor.set(idx + Offset, value);}
 
-	AccessorType get(uint32_t idx) {return accessor.get(idx + Offset);}
+	template <typename T>
+	void get(uint32_t idx, NBL_REF_ARG(T) value) {accessor.get(idx + Offset, value);}
 
 	// TODO: figure out the `enable_if` syntax for this
 	void workgroupExecutionAndMemoryBarrier() {accessor.workgroupExecutionAndMemoryBarrier();}
 };
 
 // Dynamic offset version
-template<class BaseAccessor, class AccessorType>
+template<class BaseAccessor>
 struct DynamicOffsetAccessor
 {
 	BaseAccessor accessor;
 	uint32_t offset;
 
-	void set(uint32_t idx, NBL_REF_ARG(AccessorType) x) {accessor.set(idx + offset, x);}
+	template <typename T>
+	void set(uint32_t idx, T value) {accessor.set(idx + offset, value);}
 
-	AccessorType get(uint32_t idx) {return accessor.get(idx + offset);}
+	template <typename T>
+	void get(uint32_t idx, NBL_REF_ARG(T) value) {accessor.get(idx + offset, value);}
 
 	// TODO: figure out the `enable_if` syntax for this
 	void workgroupExecutionAndMemoryBarrier() {accessor.workgroupExecutionAndMemoryBarrier();}
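With the element type moved from the struct template onto the member templates, one offset accessor instance can now serve several element types and callers no longer spell out an AccessorType. A hedged usage sketch, reusing the hypothetical ScratchProxy from the earlier sketch (all instance names illustrative, not from this commit):

OffsetAccessor<ScratchProxy, 128> offsetAccessor; // compile-time offset of 128 elements
offsetAccessor.accessor = scratchProxy;           // hypothetical proxy instance

float32_t f;
offsetAccessor.get(7u, f);            // T deduced as float32_t
offsetAccessor.set(7u, uint32_t(42)); // T deduced as uint32_t; no per-type accessor needed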

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 14 additions & 13 deletions
@@ -70,8 +70,9 @@ struct FFT<2,false, Scalar, device_capabilities>
 		const uint32_t hiIx = _NBL_HLSL_WORKGROUP_SIZE_ | loIx;
 
 		// Read lo, hi values from global memory
-		complex_t<Scalar> lo = accessor.get(loIx);
-		complex_t<Scalar> hi = accessor.get(hiIx);
+		complex_t<Scalar> lo, hi;
+		accessor.get(loIx, lo);
+		accessor.get(hiIx, hi);
 
 		// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
 		if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
@@ -124,8 +125,9 @@ struct FFT<2,true, Scalar, device_capabilities>
 		const uint32_t hiIx = _NBL_HLSL_WORKGROUP_SIZE_ | loIx;
 
 		// Read lo, hi values from global memory
-		complex_t<Scalar> lo = accessor.get(loIx);
-		complex_t<Scalar> hi = accessor.get(hiIx);
+		complex_t<Scalar> lo, hi;
+		accessor.get(loIx, lo);
+		accessor.get(hiIx, hi);
 
 		// Run a subgroup-sized FFT, then continue with bigger steps
 		subgroup::FFT<true, Scalar, device_capabilities>::__call(lo, hi);
@@ -157,9 +159,6 @@ struct FFT<2,true, Scalar, device_capabilities>
 	}
 };
 
-
-// ---------------------------- Below pending --------------------------------------------------
-
 // Forward FFT
 template<uint32_t K, typename Scalar, class device_capabilities>
 struct FFT<K, false, Scalar, device_capabilities>
@@ -175,8 +174,9 @@ struct FFT<K, false, Scalar, device_capabilities>
 		const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
 		const uint32_t hiIx = loIx | stride;
 
-		complex_t<Scalar> lo = accessor.get(loIx);
-		complex_t<Scalar> hi = accessor.get(hiIx);
+		complex_t<Scalar> lo, hi;
+		accessor.get(loIx, lo);
+		accessor.get(hiIx, hi);
 
 		hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false,Scalar>(virtualThreadID & (stride - 1), stride),lo,hi);
 
@@ -187,7 +187,7 @@ struct FFT<K, false, Scalar, device_capabilities>
 		}
 
 		// do K/2 small workgroup FFTs
-		DynamicOffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
+		DynamicOffsetAccessor <Accessor> offsetAccessor;
 		//[unroll(K/2)]
 		for (uint32_t k = 0; k < K; k += 2)
 		{
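For reference, the retemplated DynamicOffsetAccessor<Accessor> declared above is driven inside the K/2 loop roughly as sketched below; the offset expression and the exact sub-FFT invocation are assumptions for illustration, not lines from this commit.

DynamicOffsetAccessor<Accessor> offsetAccessor;
offsetAccessor.accessor = accessor;
//[unroll(K/2)]
for (uint32_t k = 0; k < K; k += 2)
{
	// assumed layout: each K=2 sub-FFT works on its own 2 * WorkgroupSize element window
	offsetAccessor.offset = k * _NBL_HLSL_WORKGROUP_SIZE_;
	FFT<2, false, Scalar, device_capabilities>::__call(offsetAccessor, sharedmemAccessor);
}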
@@ -208,7 +208,7 @@ struct FFT<K, true, Scalar, device_capabilities>
 	static enable_if_t< (mpl::is_pot_v<K> && K > 2), void > __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
 	{
 		// do K/2 small workgroup FFTs
-		DynamicOffsetAccessor < Accessor, complex_t<Scalar> > offsetAccessor;
+		DynamicOffsetAccessor <Accessor> offsetAccessor;
 		//[unroll(K/2)]
 		for (uint32_t k = 0; k < K; k += 2)
 		{
@@ -228,8 +228,9 @@ struct FFT<K, true, Scalar, device_capabilities>
 		const uint32_t loIx = ((virtualThreadID & (~(stride - 1))) << 1) | (virtualThreadID & (stride - 1));
 		const uint32_t hiIx = loIx | stride;
 
-		complex_t<Scalar> lo = accessor.get(loIx);
-		complex_t<Scalar> hi = accessor.get(hiIx);
+		complex_t<Scalar> lo, hi;
+		accessor.get(loIx, lo);
+		accessor.get(hiIx, hi);
 
 		hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true,Scalar>(virtualThreadID & (stride - 1), stride), lo,hi);
 
include/nbl/builtin/hlsl/workgroup/shuffle.hlsl

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ struct shuffleXor
 		// Wait until all writes are done before reading
 		sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
 
-		value = sharedmemAccessor.get(threadID ^ mask);
+		sharedmemAccessor.get(threadID ^ mask, value);
 	}
 
 	static void __call(NBL_REF_ARG(T) value, uint32_t mask, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
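A hedged call-site sketch for shuffleXor through the new out-parameter accessor convention; the template parameter list <T, SharedMemoryAccessor>, the proxy type, and the instance names are assumptions for illustration, not taken from this commit.

// Hypothetical call site, reusing the ScratchProxy sketched earlier.
float32_t v = float32_t(threadID); // per-invocation value to exchange (threadID assumed available)
shuffleXor<float32_t, ScratchProxy>::__call(v, 0x1u, scratchProxy);
// v now holds the value written by the invocation whose thread ID differs in bit 0.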
