
Commit 2f1d3de

Merged accessor fix
2 parents: 250e304 + 73a10bf

File tree

3 files changed: +101 -35 lines changed


examples_tests

include/nbl/builtin/hlsl/bda/bda_accessor.hlsl

Lines changed: 6 additions & 0 deletions
@@ -23,6 +23,12 @@ struct BdaAccessor
         return accessor;
     }
 
+    T get(const uint64_t index)
+    {
+        bda::__ptr<T> target = ptr + index;
+        return target.template deref().load();
+    }
+
     void get(const uint64_t index, NBL_REF_ARG(T) value)
     {
         bda::__ptr<T> target = ptr + index;

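The new overload returns the element by value, complementing the pre-existing out-parameter get. A minimal usage sketch (illustrative only, not part of the commit: it assumes an already-created BdaAccessor<float> and omits namespace qualification, since the factory and the enclosing namespace are outside this hunk):

// Hypothetical sketch: read the same element through both get() flavours.
float readTwice(BdaAccessor<float> accessor, const uint64_t i)
{
    float byRef;
    accessor.get(i, byRef);                // pre-existing out-parameter overload
    const float byValue = accessor.get(i); // by-value overload added in this commit
    return byRef + byValue;
}
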
include/nbl/builtin/hlsl/memory_accessor.hlsl

Lines changed: 94 additions & 34 deletions
@@ -25,33 +25,59 @@ struct MemoryAdaptor
     }
 
     template<typename Scalar>
-    void get(const uint ix, NBL_REF_ARG(Scalar) value) { accessor.get(ix, value);}
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(Scalar) value)
+    {
+        uint32_t aux;
+        accessor.get(ix, aux);
+        value = bit_cast<Scalar, uint32_t>(aux);
+    }
     template<typename Scalar>
-    void get(const uint ix, NBL_REF_ARG(vector <Scalar, 2>) value) { accessor.get(ix, value.x), accessor.get(ix + Stride, value.y);}
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 2>) value)
+    {
+        uint32_t2 aux;
+        accessor.get(ix, aux.x);
+        accessor.get(ix + Stride, aux.y);
+        value = bit_cast<vector<Scalar, 2>, uint32_t2>(aux);
+    }
     template<typename Scalar>
-    void get(const uint ix, NBL_REF_ARG(vector <Scalar, 3>) value) { accessor.get(ix, value.x), accessor.get(ix + Stride, value.y), accessor.get(ix + 2 * Stride, value.z);}
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 3>) value)
+    {
+        uint32_t3 aux;
+        accessor.get(ix, aux.x);
+        accessor.get(ix + Stride, aux.y);
+        accessor.get(ix + 2 * Stride, aux.z);
+        value = bit_cast<vector<Scalar, 3>, uint32_t3>(aux);
+    }
     template<typename Scalar>
-    void get(const uint ix, NBL_REF_ARG(vector <Scalar, 4>) value) { accessor.get(ix, value.x), accessor.get(ix + Stride, value.y), accessor.get(ix + 2 * Stride, value.z), accessor.get(ix + 3 * Stride, value.w);}
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 4>) value)
+    {
+        uint32_t4 aux;
+        accessor.get(ix, aux.x);
+        accessor.get(ix + Stride, aux.y);
+        accessor.get(ix + 2 * Stride, aux.z);
+        accessor.get(ix + 3 * Stride, aux.w);
+        value = bit_cast<vector<Scalar, 4>, uint32_t4>(aux);
+    }
 
     template<typename Scalar>
-    void set(const uint ix, const Scalar value) {accessor.set(ix, value);}
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const Scalar value) {accessor.set(ix, asuint(value));}
     template<typename Scalar>
-    void set(const uint ix, const vector <Scalar, 2> value) {
-        accessor.set(ix, value.x);
-        accessor.set(ix + Stride, value.y);
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 2> value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + Stride, asuint(value.y));
     }
     template<typename Scalar>
-    void set(const uint ix, const vector <Scalar, 3> value) {
-        accessor.set(ix, value.x);
-        accessor.set(ix + Stride, value.y);
-        accessor.set(ix + 2 * Stride, value.z);
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 3> value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + Stride, asuint(value.y));
+        accessor.set(ix + 2 * Stride, asuint(value.z));
     }
     template<typename Scalar>
-    void set(const uint ix, const vector <Scalar, 4> value) {
-        accessor.set(ix, value.x);
-        accessor.set(ix + Stride, value.y);
-        accessor.set(ix + 2 * Stride, value.z);
-        accessor.set(ix + 3 * Stride, value.w);
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 4> value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + Stride, asuint(value.y));
+        accessor.set(ix + 2 * Stride, asuint(value.z));
+        accessor.set(ix + 3 * Stride, asuint(value.w));
     }
 
     void atomicAnd(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {
@@ -110,34 +136,68 @@ struct MemoryAdaptor<BaseAccessor, 0>
     BaseAccessor accessor;
     uint32_t stride;
 
+    // TODO: template atomic... then add static_asserts of `has_method<BaseAccessor,signature>::value`, do vectors and matrices in terms of each other
+    uint get(const uint ix)
+    {
+        uint retVal;
+        accessor.get(ix, retVal);
+        return retVal;
+    }
+
     template<typename Scalar>
-    void get(const uint ix, NBL_REF_ARG(Scalar) value) { accessor.get(ix, value);}
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(Scalar) value)
+    {
+        uint32_t aux;
+        accessor.get(ix, aux);
+        value = bit_cast<Scalar, uint32_t>(aux);
+    }
     template<typename Scalar>
-    void get(const uint ix, NBL_REF_ARG(vector <Scalar, 2>) value) { accessor.get(ix, value.x), accessor.get(ix + stride, value.y);}
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 2>) value)
+    {
+        uint32_t2 aux;
+        accessor.get(ix, aux.x);
+        accessor.get(ix + stride, aux.y);
+        value = bit_cast<vector<Scalar, 2>, uint32_t2>(aux);
+    }
     template<typename Scalar>
-    void get(const uint ix, NBL_REF_ARG(vector <Scalar, 3>) value) { accessor.get(ix, value.x), accessor.get(ix + stride, value.y), accessor.get(ix + 2 * stride, value.z);}
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 3>) value)
+    {
+        uint32_t3 aux;
+        accessor.get(ix, aux.x);
+        accessor.get(ix + stride, aux.y);
+        accessor.get(ix + 2 * stride, aux.z);
+        value = bit_cast<vector<Scalar, 3>, uint32_t3>(aux);
+    }
     template<typename Scalar>
-    void get(const uint ix, NBL_REF_ARG(vector <Scalar, 4>) value) { accessor.get(ix, value.x), accessor.get(ix + stride, value.y), accessor.get(ix + 2 * stride, value.z), accessor.get(ix + 3 * stride, value.w);}
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> get(const uint ix, NBL_REF_ARG(vector <Scalar, 4>) value)
+    {
+        uint32_t4 aux;
+        accessor.get(ix, aux.x);
+        accessor.get(ix + stride, aux.y);
+        accessor.get(ix + 2 * stride, aux.z);
+        accessor.get(ix + 3 * stride, aux.w);
+        value = bit_cast<vector<Scalar, 4>, uint32_t4>(aux);
+    }
 
     template<typename Scalar>
-    void set(const uint ix, const Scalar value) {accessor.set(ix, value);}
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const Scalar value) {accessor.set(ix, asuint(value));}
     template<typename Scalar>
-    void set(const uint ix, const vector <Scalar, 2> value) {
-        accessor.set(ix, value.x);
-        accessor.set(ix + stride, value.y);
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 2> value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + stride, asuint(value.y));
     }
     template<typename Scalar>
-    void set(const uint ix, const vector <Scalar, 3> value) {
-        accessor.set(ix, value.x);
-        accessor.set(ix + stride, value.y);
-        accessor.set(ix + 2 * stride, value.z);
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 3> value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + stride, asuint(value.y));
+        accessor.set(ix + 2 * stride, asuint(value.z));
    }
     template<typename Scalar>
-    void set(const uint ix, const vector <Scalar, 4> value) {
-        accessor.set(ix, value.x);
-        accessor.set(ix + stride, value.y);
-        accessor.set(ix + 2 * stride, value.z);
-        accessor.set(ix + 3 * stride, value.w);
+    enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> set(const uint ix, const vector <Scalar, 4> value) {
+        accessor.set(ix, asuint(value.x));
+        accessor.set(ix + stride, asuint(value.y));
+        accessor.set(ix + 2 * stride, asuint(value.z));
+        accessor.set(ix + 3 * stride, asuint(value.w));
     }
 
     void atomicAnd(const uint ix, const uint value, NBL_REF_ARG(uint) orig) {

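Net effect of the memory_accessor.hlsl changes: get/set no longer forward Scalar values straight to the BaseAccessor. Every method is now constrained with enable_if_t<sizeof(Scalar) == sizeof(uint32_t), void> and round-trips through uint32_t via asuint()/bit_cast, so a BaseAccessor only has to expose uint get/set to serve any 32-bit scalar or vector element type. A minimal sketch of that contract (illustrative only, not from the commit: the ScratchAccessor struct, the groupshared array, the WorkgroupSize constant, and the nbl::hlsl namespace qualification are all assumptions):

// Hypothetical sketch: a uint-backed shared-memory accessor wrapped by MemoryAdaptor.
#include <nbl/builtin/hlsl/memory_accessor.hlsl>

static const uint32_t WorkgroupSize = 256u;         // assumed workgroup size
groupshared uint32_t scratch[3u * WorkgroupSize];   // backing storage holds raw uints only

struct ScratchAccessor
{
    // After this commit the adaptor only ever passes uint32_t here,
    // so the base accessor needs no per-type overloads.
    void get(const uint ix, inout uint32_t value) { value = scratch[ix]; }
    void set(const uint ix, const uint32_t value) { scratch[ix] = value; }
};

void roundTrip(const uint tid, const float3 v)
{
    nbl::hlsl::MemoryAdaptor<ScratchAccessor, WorkgroupSize> adaptor; // ScratchAccessor is stateless
    adaptor.set(tid, v);     // three asuint() stores, Stride elements apart
    float3 back;
    adaptor.get(tid, back);  // three uint loads, bit_cast back to float
}
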
0 commit comments
