Skip to content

Commit e0013cb

Browse files
author
kevyuu
committed
Fix normal quantization cache
1 parent 75d486d commit e0013cb

File tree

1 file changed

+69
-40
lines changed

1 file changed

+69
-40
lines changed

include/nbl/asset/utils/CDirQuantCacheBase.h

Lines changed: 69 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -43,23 +43,23 @@ class CDirQuantCacheBase
4343

4444
Vector8u3() : x(0u),y(0u),z(0u) {}
4545
Vector8u3(const Vector8u3&) = default;
46-
explicit Vector8u3(const hlsl::float32_t3& val)
46+
explicit Vector8u3(const hlsl::uint32_t4& val)
4747
{
4848
operator=(val);
4949
}
5050

5151
Vector8u3& operator=(const Vector8u3&) = default;
52-
Vector8u3& operator=(const hlsl::float32_t3& val)
52+
Vector8u3& operator=(const hlsl::uint32_t4& val)
5353
{
5454
x = val.x;
5555
y = val.y;
5656
z = val.z;
5757
return *this;
5858
}
5959

60-
hlsl::float32_t3 getValue() const
60+
hlsl::uint32_t4 getValue() const
6161
{
62-
return { x, y, z };
62+
return { x, y, z, 0 };
6363
}
6464

6565

@@ -75,24 +75,24 @@ class CDirQuantCacheBase
7575

7676
Vector8u4() : x(0u),y(0u),z(0u),w(0u) {}
7777
Vector8u4(const Vector8u4&) = default;
78-
explicit Vector8u4(const hlsl::float32_t3& val)
78+
explicit Vector8u4(const hlsl::uint32_t4& val)
7979
{
8080
operator=(val);
8181
}
8282

8383
Vector8u4& operator=(const Vector8u4&) = default;
84-
Vector8u4& operator=(const hlsl::float32_t3& val)
84+
Vector8u4& operator=(const hlsl::uint32_t4& val)
8585
{
8686
x = val.x;
8787
y = val.y;
8888
z = val.z;
89-
w = 0;
89+
w = val.w;
9090
return *this;
9191
}
9292

93-
hlsl::float32_t3 getValue() const
93+
hlsl::uint32_t4 getValue() const
9494
{
95-
return { x, y, z };
95+
return { x, y, z, w };
9696
}
9797

9898
private:
@@ -109,17 +109,16 @@ class CDirQuantCacheBase
109109

110110
Vector1010102() : storage(0u) {}
111111
Vector1010102(const Vector1010102&) = default;
112-
explicit Vector1010102(const hlsl::float32_t3& val)
112+
explicit Vector1010102(const hlsl::uint32_t4& val)
113113
{
114114
operator=(val);
115115
}
116116

117117
Vector1010102& operator=(const Vector1010102&) = default;
118-
Vector1010102& operator=(const hlsl::float32_t3& val)
118+
Vector1010102& operator=(const hlsl::uint32_t4& val)
119119
{
120120
constexpr auto storageBits = quantizationBits + 1u;
121-
hlsl::uint32_t3 u32_val = { val.x, val.y, val.z };
122-
storage = u32_val.x | (u32_val.y << storageBits) | (u32_val.z << (storageBits * 2u));
121+
storage = val.x | (val.y << storageBits) | (val.z << (storageBits * 2u));
123122
return *this;
124123
}
125124

@@ -132,11 +131,11 @@ class CDirQuantCacheBase
132131
return storage==other.storage;
133132
}
134133

135-
hlsl::float32_t3 getValue() const
134+
hlsl::uint32_t4 getValue() const
136135
{
137136
constexpr auto storageBits = quantizationBits + 1u;
138137
const auto mask = (0x1u << storageBits) - 1u;
139-
return { storage & mask, (storage >> storageBits) & mask, (storage >> (storageBits * 2)) & mask};
138+
return { storage & mask, (storage >> storageBits) & mask, (storage >> (storageBits * 2)) & mask, 0};
140139
}
141140

142141
private:
@@ -151,23 +150,23 @@ class CDirQuantCacheBase
151150

152151
Vector16u3() : x(0u),y(0u),z(0u) {}
153152
Vector16u3(const Vector16u3&) = default;
154-
explicit Vector16u3(const hlsl::float32_t3& val)
153+
explicit Vector16u3(const hlsl::uint32_t4& val)
155154
{
156155
operator=(val);
157156
}
158157

159158
Vector16u3& operator=(const Vector16u3&) = default;
160-
Vector16u3& operator=(const hlsl::float32_t3& val)
159+
Vector16u3& operator=(const hlsl::uint32_t4& val)
161160
{
162161
x = val.x;
163162
y = val.y;
164163
z = val.z;
165164
return *this;
166165
}
167166

168-
hlsl::float32_t3 getValue() const
167+
hlsl::uint32_t4 getValue() const
169168
{
170-
return { x, y, z };
169+
return { x, y, z, 0 };
171170
}
172171

173172
private:
@@ -182,24 +181,24 @@ class CDirQuantCacheBase
182181

183182
Vector16u4() : x(0u),y(0u),z(0u),w(0u) {}
184183
Vector16u4(const Vector16u4&) = default;
185-
explicit Vector16u4(const hlsl::float32_t3& val)
184+
explicit Vector16u4(const hlsl::uint32_t4& val)
186185
{
187186
operator=(val);
188187
}
189188

190189
Vector16u4& operator=(const Vector16u4&) = default;
191-
Vector16u4& operator=(const hlsl::float32_t3& val)
190+
Vector16u4& operator=(const hlsl::uint32_t4& val)
192191
{
193192
x = val.x;
194193
y = val.y;
195194
z = val.z;
196-
w = 0;
195+
w = val.w;
197196
return *this;
198197
}
199198

200-
hlsl::float32_t3 getValue() const
199+
hlsl::uint32_t4 getValue() const // unsigned integer lanes, like the other quantized vector types, so quantize() can apply ^, += and & to the result
201200
{
202-
return { x, y, z };
201+
return { x, y, z, w };
203202
}
204203

205204
private:
@@ -379,11 +378,28 @@ class CDirQuantCacheBase : public virtual core::IReferenceCounted, public impl::
379378
std::tuple<cache_type_t<Formats>...> cache;
380379

381380
template<uint32_t dimensions, E_FORMAT CacheFormat>
382-
value_type_t<CacheFormat> quantize(const hlsl::float32_t3& value)
381+
value_type_t<CacheFormat> quantize(const hlsl::vector<hlsl::float32_t, dimensions>& value)
383382
{
384-
const auto negativeMask = lessThan(value, hlsl::float32_t3(0.0f));
385-
386-
const hlsl::float32_t3 absValue = abs(value);
383+
auto to_float32_t4 = [](hlsl::vector<hlsl::float32_t, dimensions> src) -> hlsl::float32_t4
384+
{
385+
if constexpr(dimensions == 1)
386+
{
387+
return {src.x, 0, 0, 0};
388+
} else if constexpr (dimensions == 2)
389+
{
390+
return {src.x, src.y, 0, 0};
391+
} else if constexpr (dimensions == 3)
392+
{
393+
return {src.x, src.y, src.z, 0};
394+
} else if constexpr (dimensions == 4)
395+
{
396+
return {src.x, src.y, src.z, src.w};
397+
}
398+
};
399+
400+
const auto negativeMask = to_float32_t4(lessThan(value, hlsl::vector<hlsl::float32_t, dimensions>(0.0f)));
401+
402+
const hlsl::vector<hlsl::float32_t, dimensions> absValue = abs(value);
387403
const auto key = Key(absValue);
388404

389405
constexpr auto quantizationBits = quantization_bits_v<CacheFormat>;
@@ -397,29 +413,42 @@ class CDirQuantCacheBase : public virtual core::IReferenceCounted, public impl::
397413
{
398414
const auto fit = findBestFit<dimensions,quantizationBits>(absValue);
399415

400-
quantized = abs(fit);
416+
const auto abs_fit = to_float32_t4(abs(fit));
417+
quantized = hlsl::uint32_t4(abs_fit.x, abs_fit.y, abs_fit.z, abs_fit.w);
418+
401419
insertIntoCache<CacheFormat>(key,quantized);
402420
}
403421
}
404422

405-
//return quantized.
406-
const auto negativeMulVec = hlsl::float32_t3(negativeMask.x ? -1 : 1, negativeMask.y ? -1 : 1, negativeMask.z ? -1 : 1);
407-
return value_type_t<CacheFormat>(negativeMulVec * quantized.getValue());
423+
auto switch_vec = [](hlsl::uint32_t4 val1, hlsl::uint32_t4 val2, hlsl::bool4 mask)
424+
{
425+
hlsl::uint32_t4 retval;
426+
retval.x = mask.x ? val2.x : val1.x;
427+
retval.y = mask.y ? val2.y : val1.y;
428+
retval.z = mask.z ? val2.z : val1.z;
429+
retval.w = mask.w ? val2.w : val1.w;
430+
return retval;
431+
};
432+
433+
const hlsl::uint32_t4 xorflag((0x1u << (quantizationBits + 1u)) - 1u);
434+
auto restoredAsVec = quantized.getValue() ^ switch_vec(hlsl::uint32_t4(0u), hlsl::uint32_t4(xorflag), negativeMask);
435+
restoredAsVec += switch_vec(hlsl::uint32_t4(0u), hlsl::uint32_t4(1u), negativeMask);
436+
return value_type_t<CacheFormat>(restoredAsVec & xorflag);
408437
}
409438

410439
template<uint32_t dimensions, uint32_t quantizationBits>
411-
static inline hlsl::float32_t3 findBestFit(const hlsl::float32_t3& value)
440+
static inline hlsl::vector<hlsl::float32_t, dimensions> findBestFit(const hlsl::vector<hlsl::float32_t, dimensions>& value)
412441
{
413442
static_assert(dimensions>1u,"No point");
414443
static_assert(dimensions<=4u,"High Dimensions are Hard!");
415444

416445
const auto vectorForDots = hlsl::normalize(value);
417446

418447
//
419-
hlsl::float32_t3 fittingVector;
420-
hlsl::float32_t3 floorOffset;
448+
hlsl::vector<hlsl::float32_t, dimensions> fittingVector;
449+
hlsl::vector<hlsl::float32_t, dimensions> floorOffset;
421450
constexpr uint32_t cornerCount = (0x1u<<(dimensions-1u))-1u;
422-
hlsl::float32_t3 corners[cornerCount] = {};
451+
hlsl::vector<hlsl::float32_t, dimensions> corners[cornerCount] = {};
423452
{
424453
uint32_t maxDirCompIndex = 0u;
425454
for (auto i=1u; i<dimensions; i++)
@@ -431,7 +460,7 @@ class CDirQuantCacheBase : public virtual core::IReferenceCounted, public impl::
431460
if (maxDirectionComp < std::sqrtf(0.9998f / float(dimensions)))
432461
{
433462
_NBL_DEBUG_BREAK_IF(true);
434-
return hlsl::float32_t3(0.f);
463+
return hlsl::vector<hlsl::float32_t, dimensions>(0.f);
435464
}
436465
fittingVector = value / maxDirectionComp;
437466
floorOffset[maxDirCompIndex] = 0.499f;
@@ -453,9 +482,9 @@ class CDirQuantCacheBase : public virtual core::IReferenceCounted, public impl::
453482
}
454483
}
455484

456-
hlsl::float32_t3 bestFit;
485+
hlsl::vector<hlsl::float32_t, dimensions> bestFit;
457486
float closestTo1 = -1.f;
458-
auto evaluateFit = [&](const hlsl::float32_t3& newFit) -> void
487+
auto evaluateFit = [&](const hlsl::vector<hlsl::float32_t, dimensions>& newFit) -> void
459488
{
460489
auto newFitLen = length(newFit);
461490
const float dp = hlsl::dot(newFit,vectorForDots) / (newFitLen);
@@ -467,7 +496,7 @@ class CDirQuantCacheBase : public virtual core::IReferenceCounted, public impl::
467496
};
468497

469498
constexpr uint32_t cubeHalfSize = (0x1u << quantizationBits) - 1u;
470-
const hlsl::float32_t3 cubeHalfSizeND = hlsl::float32_t3(cubeHalfSize);
499+
const hlsl::vector<hlsl::float32_t, dimensions> cubeHalfSizeND = hlsl::vector<hlsl::float32_t, dimensions>(cubeHalfSize);
471500
for (uint32_t n=cubeHalfSize; n>0u; n--)
472501
{
473502
//we'd use float addition in the interest of speed, to increment the loop

0 commit comments

Comments
 (0)