Skip to content

Commit dc8f229

Browse files
committed
feat: literal decimal compare & add more test cases
1 parent 7aa5b4e commit dc8f229

File tree

3 files changed

+65
-4
lines changed

3 files changed

+65
-4
lines changed

src/iceberg/expression/literal.cc

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,42 @@ std::strong_ordering CompareFloat(T lhs, T rhs) {
344344
return lhs_is_negative <=> rhs_is_negative;
345345
}
346346

347+
std::strong_ordering CompareDecimal(Literal const& lhs, Literal const& rhs) {
348+
ICEBERG_DCHECK(std::holds_alternative<int128_t>(lhs.value()),
349+
"LHS of decimal comparison must hold int128_t");
350+
ICEBERG_DCHECK(std::holds_alternative<int128_t>(rhs.value()),
351+
"RHS of decimal comparison must hold int128_t");
352+
const auto& lhs_type = std::dynamic_pointer_cast<DecimalType>(lhs.type());
353+
const auto& rhs_type = std::dynamic_pointer_cast<DecimalType>(rhs.type());
354+
ICEBERG_DCHECK(lhs_type != nullptr, "LHS type must be DecimalType");
355+
ICEBERG_DCHECK(rhs_type != nullptr, "RHS type must be DecimalType");
356+
auto lhs_val = std::get<int128_t>(lhs.value());
357+
auto rhs_val = std::get<int128_t>(rhs.value());
358+
if (lhs_type->scale() == rhs_type->scale()) {
359+
return lhs_val <=> rhs_val;
360+
} else if (lhs_type->scale() > rhs_type->scale()) {
361+
auto lhs_decimal = Decimal(lhs_val);
362+
// Rescale to larger scale
363+
auto rhs_decimal = Decimal(rhs_val).Rescale(rhs_type->scale(), lhs_type->scale());
364+
if (!rhs_decimal) {
365+
// Rescale would cause data loss, so lhs is definitely less than rhs
366+
return std::strong_ordering::less;
367+
}
368+
return lhs_decimal <=> rhs_decimal.value();
369+
} else {
370+
auto rhs_decimal = Decimal(rhs_val);
371+
// Rescale to larger scale
372+
auto lhs_decimal = Decimal(lhs_val).Rescale(lhs_type->scale(), rhs_type->scale());
373+
if (!lhs_decimal) {
374+
// Rescale would cause data loss, so lhs is definitely greater than rhs
375+
return std::strong_ordering::greater;
376+
}
377+
return lhs_decimal.value() <=> rhs_decimal;
378+
}
379+
380+
return lhs_val <=> rhs_val;
381+
}
382+
347383
bool Literal::operator==(const Literal& other) const { return (*this <=> other) == 0; }
348384

349385
// Three-way comparison operator
@@ -410,10 +446,7 @@ std::partial_ordering Literal::operator<=>(const Literal& other) const {
410446
}
411447

412448
case TypeId::kDecimal: {
413-
// TODO(zhjwpku): Handle precision/scale differences
414-
auto this_val = std::get<int128_t>(value_);
415-
auto other_val = std::get<int128_t>(other.value_);
416-
return this_val <=> other_val;
449+
return CompareDecimal(*this, other);
417450
}
418451

419452
default:

test/decimal_test.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ TEST(DecimalTest, ToBigEndian) {
502502
INT64_MIN};
503503
std::vector<uint64_t> low_values = {0,
504504
1,
505+
255,
505506
UINT32_MAX,
506507
static_cast<uint64_t>(UINT32_MAX) + 1,
507508
static_cast<uint64_t>(UINT32_MAX) + 2,
@@ -522,6 +523,15 @@ TEST(DecimalTest, ToBigEndian) {
522523
EXPECT_EQ(result.value(), value);
523524
}
524525
}
526+
527+
for (int128_t value : std::vector<int128_t>{-INT64_MAX, -INT32_MAX, -255, -1, 0, 1, 255,
528+
256, INT32_MAX, INT64_MAX}) {
529+
Decimal decimal(value);
530+
auto bytes = Decimal::ToBigEndian(decimal.value());
531+
auto result = Decimal::FromBigEndian(bytes.data(), bytes.size());
532+
ASSERT_THAT(result, IsOk());
533+
EXPECT_EQ(result.value(), decimal);
534+
}
525535
}
526536

527537
TEST(DecimalTestFunctionality, Multiply) {

test/literal_test.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,24 @@ TEST(LiteralTest, DoubleZeroComparison) {
384384
EXPECT_EQ(neg_zero <=> pos_zero, std::partial_ordering::less);
385385
}
386386

387+
TEST(LiteralTest, DecimalComparison) {
388+
auto dec1 = Literal::Decimal("123.45");
389+
auto dec2 = Literal::Decimal("123.450");
390+
auto dec3 = Literal::Decimal("123.46");
391+
auto dec4 = Literal::Decimal("-123.45");
392+
393+
ASSERT_THAT(dec1, IsOk());
394+
ASSERT_THAT(dec2, IsOk());
395+
ASSERT_THAT(dec3, IsOk());
396+
ASSERT_THAT(dec4, IsOk());
397+
398+
EXPECT_EQ((*dec1 <=> *dec2), std::partial_ordering::equivalent);
399+
EXPECT_EQ((*dec1 <=> *dec3), std::partial_ordering::less);
400+
EXPECT_EQ((*dec3 <=> *dec1), std::partial_ordering::greater);
401+
EXPECT_EQ((*dec1 <=> *dec4), std::partial_ordering::greater);
402+
EXPECT_EQ((*dec4 <=> *dec1), std::partial_ordering::less);
403+
}
404+
387405
TEST(LiteralTest, SerdeTest) {
388406
// int32
389407
auto int_literal = Literal::Int(42);

0 commit comments

Comments
 (0)