Skip to content

Commit 7aa5b4e

Browse files
committed
feat: Literal support decimal & Literal serde
1 parent 3b945a0 commit 7aa5b4e

File tree

7 files changed

+326
-8
lines changed

7 files changed

+326
-8
lines changed

src/iceberg/expression/literal.cc

Lines changed: 169 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,13 @@
2121

2222
#include <cmath>
2323
#include <concepts>
24+
#include <utility>
2425

2526
#include "iceberg/exception.h"
27+
#include "iceberg/result.h"
28+
#include "iceberg/util/decimal.h"
29+
#include "iceberg/util/endian.h"
30+
#include "iceberg/util/macros.h"
2631

2732
namespace iceberg {
2833

@@ -149,13 +154,168 @@ Literal Literal::Binary(std::vector<uint8_t> value) {
149154
return {Value{std::move(value)}, binary()};
150155
}
151156

157+
Literal Literal::Decimal(int128_t value, int32_t precision, int32_t scale) {
158+
return {Value{value}, decimal(precision, scale)};
159+
}
160+
161+
Result<Literal> Literal::Decimal(std::string_view value) {
162+
int32_t precision = 0;
163+
int32_t scale = 0;
164+
ICEBERG_ASSIGN_OR_RAISE(auto decimal_value,
165+
Decimal::FromString(value, &precision, &scale));
166+
return Literal{Value{decimal_value.value()}, decimal(precision, scale)};
167+
}
168+
152169
Result<Literal> Literal::Deserialize(std::span<const uint8_t> data,
153170
std::shared_ptr<PrimitiveType> type) {
154-
return NotImplemented("Deserialization of Literal is not implemented yet");
171+
Literal::Value value;
172+
switch (type->type_id()) {
173+
case TypeId::kBoolean:
174+
if (data.size() == 1 && data[0] == 1) {
175+
value = true;
176+
} else {
177+
value = false;
178+
}
179+
break;
180+
case TypeId::kInt:
181+
case TypeId::kDate:
182+
if (data.size() != sizeof(int32_t)) {
183+
return Invalid("Invalid data size for Int literal deserialization");
184+
}
185+
value = FromLittleEndian(*reinterpret_cast<const int32_t*>(data.data()));
186+
break;
187+
case TypeId::kLong:
188+
// In the case of an evolved field
189+
if (data.size() == sizeof(int32_t)) {
190+
value = static_cast<int64_t>(
191+
FromLittleEndian(*reinterpret_cast<const int32_t*>(data.data())));
192+
} else if (data.size() == sizeof(int64_t)) {
193+
value = FromLittleEndian(*reinterpret_cast<const int64_t*>(data.data()));
194+
} else {
195+
return Invalid("Invalid data size for Long literal deserialization");
196+
}
197+
break;
198+
case TypeId::kFloat:
199+
if (data.size() != sizeof(float)) {
200+
return Invalid("Invalid data size for Float literal deserialization");
201+
}
202+
value = FromLittleEndian(*reinterpret_cast<const float*>(data.data()));
203+
break;
204+
case TypeId::kDouble:
205+
// In the case of an evolved field
206+
if (data.size() == sizeof(float)) {
207+
value = static_cast<double>(
208+
FromLittleEndian(*reinterpret_cast<const float*>(data.data())));
209+
} else if (data.size() == sizeof(double)) {
210+
value = FromLittleEndian(*reinterpret_cast<const double*>(data.data()));
211+
} else {
212+
return Invalid("Invalid data size for Double literal deserialization");
213+
}
214+
break;
215+
case TypeId::kTime:
216+
case TypeId::kTimestamp:
217+
case TypeId::kTimestampTz:
218+
if (data.size() != sizeof(int64_t)) {
219+
return Invalid("Invalid data size for Timestamp/Time literal deserialization");
220+
}
221+
value = FromLittleEndian(*reinterpret_cast<const int64_t*>(data.data()));
222+
break;
223+
case TypeId::kString:
224+
value = std::string(data.begin(), data.end());
225+
break;
226+
case TypeId::kUuid:
227+
if (data.size() != 16) {
228+
return Invalid("Invalid data size for UUID literal deserialization");
229+
}
230+
value = *reinterpret_cast<const std::array<uint8_t, 16>*>(data.data());
231+
break;
232+
case TypeId::kDecimal: {
233+
ICEBERG_ASSIGN_OR_RAISE(auto unscaled_decimal,
234+
Decimal::FromBigEndian(data.data(), data.size()));
235+
value = unscaled_decimal.value();
236+
} break;
237+
case TypeId::kFixed:
238+
case TypeId::kBinary:
239+
value = std::vector<uint8_t>(data.begin(), data.end());
240+
break;
241+
default:
242+
std::unreachable();
243+
}
244+
245+
return Literal(value, std::move(type));
155246
}
156247

157248
Result<std::vector<uint8_t>> Literal::Serialize() const {
158-
return NotImplemented("Serialization of Literal is not implemented yet");
249+
if (IsAboveMax() || IsBelowMin()) {
250+
return Invalid("Cannot serialize AboveMax or BelowMin literal");
251+
}
252+
if (IsNull()) {
253+
return std::vector<uint8_t>{};
254+
}
255+
256+
switch (type_->type_id()) {
257+
case TypeId::kBoolean: {
258+
bool bool_val = std::get<bool>(value_);
259+
return std::vector<uint8_t>{static_cast<uint8_t>(bool_val ? 1 : 0)};
260+
}
261+
case TypeId::kInt:
262+
case TypeId::kDate: {
263+
int32_t int_val = std::get<int32_t>(value_);
264+
int32_t le_val = ToLittleEndian(int_val);
265+
const auto* bytes =
266+
reinterpret_cast<const uint8_t*>(static_cast<const void*>(&le_val));
267+
return std::vector<uint8_t>(bytes, bytes + sizeof(int32_t));
268+
}
269+
case TypeId::kLong: {
270+
int64_t long_val = std::get<int64_t>(value_);
271+
int64_t le_val = ToLittleEndian(long_val);
272+
const auto* bytes =
273+
reinterpret_cast<const uint8_t*>(static_cast<const void*>(&le_val));
274+
return std::vector<uint8_t>(bytes, bytes + sizeof(int64_t));
275+
}
276+
case TypeId::kFloat: {
277+
float float_val = std::get<float>(value_);
278+
float le_val = ToLittleEndian(float_val);
279+
const auto* bytes =
280+
reinterpret_cast<const uint8_t*>(static_cast<const void*>(&le_val));
281+
return std::vector<uint8_t>(bytes, bytes + sizeof(float));
282+
}
283+
case TypeId::kDouble: {
284+
double double_val = std::get<double>(value_);
285+
double le_val = ToLittleEndian(double_val);
286+
const auto* bytes =
287+
reinterpret_cast<const uint8_t*>(static_cast<const void*>(&le_val));
288+
return std::vector<uint8_t>(bytes, bytes + sizeof(double));
289+
}
290+
case TypeId::kTime:
291+
case TypeId::kTimestamp:
292+
case TypeId::kTimestampTz: {
293+
int64_t time_val = std::get<int64_t>(value_);
294+
int64_t le_val = ToLittleEndian(time_val);
295+
const auto* bytes =
296+
reinterpret_cast<const uint8_t*>(static_cast<const void*>(&le_val));
297+
return std::vector<uint8_t>(bytes, bytes + sizeof(int64_t));
298+
}
299+
case TypeId::kString: {
300+
const auto& str_val = std::get<std::string>(value_);
301+
return std::vector<uint8_t>(str_val.begin(), str_val.end());
302+
}
303+
case TypeId::kUuid: {
304+
const auto& uuid_val = std::get<std::array<uint8_t, 16>>(value_);
305+
return std::vector<uint8_t>(uuid_val.begin(), uuid_val.end());
306+
}
307+
case TypeId::kDecimal: {
308+
int128_t decimal_val = std::get<int128_t>(value_);
309+
return Decimal::ToBigEndian(decimal_val);
310+
}
311+
case TypeId::kFixed:
312+
case TypeId::kBinary: {
313+
const auto& bin_val = std::get<std::vector<uint8_t>>(value_);
314+
return bin_val;
315+
}
316+
default:
317+
std::unreachable();
318+
}
159319
}
160320

161321
// Getters
@@ -249,6 +409,13 @@ std::partial_ordering Literal::operator<=>(const Literal& other) const {
249409
return this_val <=> other_val;
250410
}
251411

412+
case TypeId::kDecimal: {
413+
// TODO(zhjwpku): Handle precision/scale differences
414+
auto this_val = std::get<int128_t>(value_);
415+
auto other_val = std::get<int128_t>(other.value_);
416+
return this_val <=> other_val;
417+
}
418+
252419
default:
253420
// For unsupported types, return unordered
254421
return std::partial_ordering::unordered;

src/iceberg/expression/literal.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@
2222
#include <compare>
2323
#include <memory>
2424
#include <string>
25+
#include <string_view>
2526
#include <variant>
2627
#include <vector>
2728

2829
#include "iceberg/result.h"
2930
#include "iceberg/type.h"
31+
#include "iceberg/util/int128.h"
3032

3133
namespace iceberg {
3234

@@ -56,7 +58,8 @@ class ICEBERG_EXPORT Literal {
5658
double, // for double
5759
std::string, // for string
5860
std::vector<uint8_t>, // for binary, fixed
59-
std::array<uint8_t, 16>, // for uuid and decimal
61+
std::array<uint8_t, 16>, // for uuid
62+
int128_t, // for decimal
6063
BelowMin, AboveMax>;
6164

6265
/// \brief Factory methods for primitive types
@@ -71,6 +74,8 @@ class ICEBERG_EXPORT Literal {
7174
static Literal Double(double value);
7275
static Literal String(std::string value);
7376
static Literal Binary(std::vector<uint8_t> value);
77+
static Literal Decimal(int128_t value, int32_t precision, int32_t scale);
78+
static Result<Literal> Decimal(std::string_view value);
7479

7580
/// \brief Create a literal representing a null value.
7681
static Literal Null(std::shared_ptr<PrimitiveType> type) {

src/iceberg/transform_function.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include "iceberg/expression/literal.h"
2929
#include "iceberg/type.h"
30+
#include "iceberg/util/int128.h"
3031
#include "iceberg/util/murmurhash3_internal.h"
3132
#include "iceberg/util/truncate_util.h"
3233

@@ -73,6 +74,8 @@ Result<Literal> BucketTransform::Transform(const Literal& literal) {
7374
MurmurHash3_x86_32(&value, sizeof(int64_t), 0, &hash_value);
7475
} else if constexpr (std::is_same_v<T, std::array<uint8_t, 16>>) {
7576
MurmurHash3_x86_32(value.data(), sizeof(uint8_t) * 16, 0, &hash_value);
77+
} else if constexpr (std::is_same_v<T, int128_t>) {
78+
MurmurHash3_x86_32(&value, sizeof(int128_t), 0, &hash_value);
7679
} else if constexpr (std::is_same_v<T, std::string>) {
7780
MurmurHash3_x86_32(value.data(), value.size(), 0, &hash_value);
7881
} else if constexpr (std::is_same_v<T, std::vector<uint8_t>>) {

src/iceberg/util/decimal.cc

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
#include "iceberg/util/decimal.h"
2626

27+
#include <algorithm>
2728
#include <bit>
2829
#include <charconv>
2930
#include <climits>
@@ -44,6 +45,9 @@ namespace iceberg {
4445

4546
namespace {
4647

48+
static constexpr int32_t kMinDecimalBytes = 1;
49+
static constexpr int32_t kMaxDecimalBytes = 16;
50+
4751
struct DecimalComponents {
4852
std::string_view while_digits;
4953
std::string_view fractional_digits;
@@ -472,11 +476,6 @@ Result<Decimal> Decimal::FromString(std::string_view str, int32_t* precision,
472476
}
473477

474478
Result<Decimal> Decimal::FromBigEndian(const uint8_t* bytes, int32_t length) {
475-
static constexpr int32_t kMinDecimalBytes = 1;
476-
static constexpr int32_t kMaxDecimalBytes = 16;
477-
478-
int64_t high, low;
479-
480479
if (length < kMinDecimalBytes || length > kMaxDecimalBytes) {
481480
return InvalidArgument(
482481
"Decimal::FromBigEndian: length must be in the range [{}, {}], was {}",
@@ -507,6 +506,36 @@ Result<Decimal> Decimal::FromBigEndian(const uint8_t* bytes, int32_t length) {
507506
return Decimal(static_cast<int128_t>(result));
508507
}
509508

509+
std::vector<uint8_t> Decimal::ToBigEndian(int128_t value) {
510+
std::vector<uint8_t> bytes(kMaxDecimalBytes);
511+
512+
auto uvalue = static_cast<uint128_t>(value);
513+
std::memcpy(bytes.data(), &uvalue, 16);
514+
515+
if constexpr (std::endian::native == std::endian::little) {
516+
std::ranges::reverse(bytes);
517+
}
518+
519+
auto is_negative = value < 0;
520+
int keep = kMaxDecimalBytes;
521+
for (int32_t i = 0; i < kMaxDecimalBytes - 1; ++i) {
522+
uint8_t byte = bytes[i];
523+
uint8_t next = bytes[i + 1];
524+
// For negative numbers, keep the leading 0xff byte if the next byte has its sign bit
525+
// unset. For positive numbers, keep the leading 0x00 byte if the next byte has its
526+
// sign bit set.
527+
if ((is_negative && byte == 0xff && (next & 0x80) == 0) ||
528+
(!is_negative && byte == 0x00 && (next & 0x80) != 0)) {
529+
--keep;
530+
} else {
531+
break;
532+
}
533+
}
534+
535+
bytes.erase(bytes.begin(), bytes.begin() + (kMaxDecimalBytes - keep));
536+
return bytes;
537+
}
538+
510539
Result<Decimal> Decimal::Rescale(int32_t orig_scale, int32_t new_scale) const {
511540
if (orig_scale == new_scale) {
512541
return *this;

src/iceberg/util/decimal.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include <string>
3131
#include <string_view>
3232
#include <type_traits>
33+
#include <vector>
3334

3435
#include "iceberg/iceberg_export.h"
3536
#include "iceberg/result.h"
@@ -164,6 +165,12 @@ class ICEBERG_EXPORT Decimal : public util::Formattable {
164165
/// \return error status if the length is an invalid value
165166
static Result<Decimal> FromBigEndian(const uint8_t* data, int32_t length);
166167

168+
/// \brief Convert Decimal's unscaled value to two’s-complement big-endian binary, using
169+
/// the minimum number of bytes for the value.
170+
/// \param value The unscaled value.
171+
/// \return A vector containing the big-endian bytes.
172+
static std::vector<uint8_t> ToBigEndian(int128_t value);
173+
167174
/// \brief Convert Decimal from one scale to another.
168175
Result<Decimal> Rescale(int32_t orig_scale, int32_t new_scale) const;
169176

test/decimal_test.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,40 @@ TEST(DecimalTest, FromBigEndianInvalid) {
490490
IsError(ErrorKind::kInvalidArgument));
491491
}
492492

493+
TEST(DecimalTest, ToBigEndian) {
494+
std::vector<int64_t> high_values = {0,
495+
1,
496+
-1,
497+
INT32_MAX,
498+
INT32_MIN,
499+
static_cast<int64_t>(INT32_MAX) + 1,
500+
static_cast<int64_t>(INT32_MIN) - 1,
501+
INT64_MAX,
502+
INT64_MIN};
503+
std::vector<uint64_t> low_values = {0,
504+
1,
505+
UINT32_MAX,
506+
static_cast<uint64_t>(UINT32_MAX) + 1,
507+
static_cast<uint64_t>(UINT32_MAX) + 2,
508+
static_cast<uint64_t>(UINT32_MAX) + 3,
509+
static_cast<uint64_t>(UINT32_MAX) + 4,
510+
static_cast<uint64_t>(UINT32_MAX) + 5,
511+
static_cast<uint64_t>(UINT32_MAX) + 6,
512+
static_cast<uint64_t>(UINT32_MAX) + 7,
513+
static_cast<uint64_t>(UINT32_MAX) + 8,
514+
UINT64_MAX};
515+
516+
for (int64_t high : high_values) {
517+
for (uint64_t low : low_values) {
518+
Decimal value(high, low);
519+
auto bytes = Decimal::ToBigEndian(value.value());
520+
auto result = Decimal::FromBigEndian(bytes.data(), bytes.size());
521+
ASSERT_THAT(result, IsOk());
522+
EXPECT_EQ(result.value(), value);
523+
}
524+
}
525+
}
526+
493527
TEST(DecimalTestFunctionality, Multiply) {
494528
ASSERT_EQ(Decimal(60501), Decimal(301) * Decimal(201));
495529
ASSERT_EQ(Decimal(-60501), Decimal(-301) * Decimal(201));

0 commit comments

Comments
 (0)