diff --git a/src/iceberg/expression/aggregate.cc b/src/iceberg/expression/aggregate.cc index a9c1a60bf..12bb3f03f 100644 --- a/src/iceberg/expression/aggregate.cc +++ b/src/iceberg/expression/aggregate.cc @@ -19,11 +19,16 @@ #include "iceberg/expression/aggregate.h" +#include #include +#include +#include #include +#include #include #include "iceberg/expression/literal.h" +#include "iceberg/manifest/manifest_entry.h" #include "iceberg/row/struct_like.h" #include "iceberg/type.h" #include "iceberg/util/checked_cast.h" @@ -38,6 +43,32 @@ std::shared_ptr GetPrimitiveType(const BoundTerm& term) { return internal::checked_pointer_cast(term.type()); } +/// \brief A single-field StructLike that wraps a Literal +class SingleValueStructLike : public StructLike { + public: + explicit SingleValueStructLike(Literal literal) : literal_(std::move(literal)) {} + + Result GetField(size_t) const override { return LiteralToScalar(literal_); } + + size_t num_fields() const override { return 1; } + + private: + Literal literal_; +}; + +Result EvaluateBoundTerm(const BoundTerm& term, + const std::optional>& bound) { + auto ptype = GetPrimitiveType(term); + if (!bound.has_value()) { + SingleValueStructLike data(Literal::Null(ptype)); + return term.Evaluate(data); + } + + ICEBERG_ASSIGN_OR_RAISE(auto literal, Literal::Deserialize(*bound, ptype)); + SingleValueStructLike data(std::move(literal)); + return term.Evaluate(data); +} + class CountAggregator : public BoundAggregate::Aggregator { public: explicit CountAggregator(const CountAggregate& aggregate) : aggregate_(aggregate) {} @@ -48,11 +79,32 @@ class CountAggregator : public BoundAggregate::Aggregator { return {}; } - Literal GetResult() const override { return Literal::Long(count_); } + Status Update(const DataFile& file) override { + if (!valid_) { + return {}; + } + if (!aggregate_.HasValue(file)) { + valid_ = false; + return {}; + } + ICEBERG_ASSIGN_OR_RAISE(auto count, aggregate_.CountFor(file)); + count_ += count; + return {}; 
+ } + + Literal GetResult() const override { + if (!valid_) { + return Literal::Null(int64()); + } + return Literal::Long(count_); + } + + bool IsValid() const override { return valid_; } private: const CountAggregate& aggregate_; int64_t count_ = 0; + bool valid_ = true; }; class MaxAggregator : public BoundAggregate::Aggregator { @@ -73,6 +125,7 @@ class MaxAggregator : public BoundAggregate::Aggregator { if (auto ordering = value <=> current_; ordering == std::partial_ordering::unordered) { + valid_ = false; return InvalidArgument("Cannot compare literal {} with current value {}", value.ToString(), current_.ToString()); } else if (ordering == std::partial_ordering::greater) { @@ -82,11 +135,48 @@ class MaxAggregator : public BoundAggregate::Aggregator { return {}; } - Literal GetResult() const override { return current_; } + Status Update(const DataFile& file) override { + if (!valid_) { + return {}; + } + if (!aggregate_.HasValue(file)) { + valid_ = false; + return {}; + } + + ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(file)); + if (value.IsNull()) { + return {}; + } + if (current_.IsNull()) { + current_ = std::move(value); + return {}; + } + + if (auto ordering = value <=> current_; + ordering == std::partial_ordering::unordered) { + valid_ = false; + return InvalidArgument("Cannot compare literal {} with current value {}", + value.ToString(), current_.ToString()); + } else if (ordering == std::partial_ordering::greater) { + current_ = std::move(value); + } + return {}; + } + + Literal GetResult() const override { + if (!valid_) { + return Literal::Null(GetPrimitiveType(*aggregate_.term())); + } + return current_; + } + + bool IsValid() const override { return valid_; } private: const MaxAggregate& aggregate_; Literal current_; + bool valid_ = true; }; class MinAggregator : public BoundAggregate::Aggregator { @@ -107,6 +197,7 @@ class MinAggregator : public BoundAggregate::Aggregator { if (auto ordering = value <=> current_; ordering == 
std::partial_ordering::unordered) { + valid_ = false; return InvalidArgument("Cannot compare literal {} with current value {}", value.ToString(), current_.ToString()); } else if (ordering == std::partial_ordering::less) { @@ -115,13 +206,66 @@ class MinAggregator : public BoundAggregate::Aggregator { return {}; } - Literal GetResult() const override { return current_; } + Status Update(const DataFile& file) override { + if (!valid_) { + return {}; + } + if (!aggregate_.HasValue(file)) { + valid_ = false; + return {}; + } + + ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(file)); + if (value.IsNull()) { + return {}; + } + if (current_.IsNull()) { + current_ = std::move(value); + return {}; + } + + if (auto ordering = value <=> current_; + ordering == std::partial_ordering::unordered) { + valid_ = false; + return InvalidArgument("Cannot compare literal {} with current value {}", + value.ToString(), current_.ToString()); + } else if (ordering == std::partial_ordering::less) { + current_ = std::move(value); + } + return {}; + } + + Literal GetResult() const override { + if (!valid_) { + return Literal::Null(GetPrimitiveType(*aggregate_.term())); + } + return current_; + } + + bool IsValid() const override { return valid_; } private: const MinAggregate& aggregate_; Literal current_; + bool valid_ = true; }; +template +std::optional GetMapValue(const std::map& map, int32_t key) { + auto iter = map.find(key); + if (iter == map.end()) { + return std::nullopt; + } + return iter->second; +} + +int32_t GetFieldId(const std::shared_ptr& term) { + ICEBERG_DCHECK(term != nullptr, "Aggregate term should not be null"); + auto ref = term->reference(); + ICEBERG_DCHECK(ref != nullptr, "Aggregate term reference should not be null"); + return ref->field().field_id(); +} + } // namespace template @@ -149,7 +293,11 @@ std::string Aggregate::ToString() const { // -------------------- CountAggregate -------------------- Result CountAggregate::Evaluate(const StructLike& data) 
const { - return CountFor(data).transform([](int64_t count) { return Literal::Long(count); }); + return CountFor(data).transform(Literal::Long); +} + +Result CountAggregate::Evaluate(const DataFile& file) const { + return CountFor(file).transform(Literal::Long); } std::unique_ptr CountAggregate::NewAggregator() const { @@ -173,6 +321,22 @@ Result CountNonNullAggregate::CountFor(const StructLike& data) const { [](const auto& val) { return val.IsNull() ? 0 : 1; }); } +Result CountNonNullAggregate::CountFor(const DataFile& file) const { + auto field_id = GetFieldId(term()); + if (!HasValue(file)) { + return NotFound("Missing metrics for field id {}", field_id); + } + auto value_count = GetMapValue(file.value_counts, field_id).value(); + auto null_count = GetMapValue(file.null_value_counts, field_id).value(); + return value_count - null_count; +} + +bool CountNonNullAggregate::HasValue(const DataFile& file) const { + auto field_id = GetFieldId(term()); + return file.value_counts.contains(field_id) && + file.null_value_counts.contains(field_id); +} + CountNullAggregate::CountNullAggregate(std::shared_ptr term) : CountAggregate(Expression::Operation::kCountNull, std::move(term)) {} @@ -189,6 +353,18 @@ Result CountNullAggregate::CountFor(const StructLike& data) const { [](const auto& val) { return val.IsNull() ? 
1 : 0; }); } +Result CountNullAggregate::CountFor(const DataFile& file) const { + auto field_id = GetFieldId(term()); + if (!HasValue(file)) { + return NotFound("Missing metrics for field id {}", field_id); + } + return GetMapValue(file.null_value_counts, field_id).value(); +} + +bool CountNullAggregate::HasValue(const DataFile& file) const { + return file.null_value_counts.contains(GetFieldId(term())); +} + CountStarAggregate::CountStarAggregate() : CountAggregate(Expression::Operation::kCountStar, nullptr) {} @@ -200,36 +376,93 @@ Result CountStarAggregate::CountFor(const StructLike& /*data*/) const { return 1; } +Result CountStarAggregate::CountFor(const DataFile& file) const { + if (!HasValue(file)) { + return NotFound("Record count is missing"); + } + return file.record_count; +} + +bool CountStarAggregate::HasValue(const DataFile& file) const { + return file.record_count >= 0; +} + MaxAggregate::MaxAggregate(std::shared_ptr term) : BoundAggregate(Expression::Operation::kMax, std::move(term)) {} -std::shared_ptr MaxAggregate::Make(std::shared_ptr term) { - return std::shared_ptr(new MaxAggregate(std::move(term))); +Result> MaxAggregate::Make( + std::shared_ptr term) { + if (!term) { + return InvalidExpression("Bound max aggregate requires non-null term"); + } + if (!term->type()->is_primitive()) { + return InvalidExpression("Max aggregate term should be primitive"); + } + return std::unique_ptr(new MaxAggregate(std::move(term))); } Result MaxAggregate::Evaluate(const StructLike& data) const { return term()->Evaluate(data); } +Result MaxAggregate::Evaluate(const DataFile& file) const { + auto field_id = GetFieldId(term()); + auto upper = GetMapValue(file.upper_bounds, field_id); + return EvaluateBoundTerm(*term(), upper); +} + std::unique_ptr MaxAggregate::NewAggregator() const { return std::unique_ptr(new MaxAggregator(*this)); } +bool MaxAggregate::HasValue(const DataFile& file) const { + auto field_id = GetFieldId(term()); + bool has_bound = 
file.upper_bounds.contains(field_id); + auto value_count = GetMapValue(file.value_counts, field_id); + auto null_count = GetMapValue(file.null_value_counts, field_id); + bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() && + null_count.value() == value_count.value(); + return has_bound || all_null; +} + MinAggregate::MinAggregate(std::shared_ptr term) : BoundAggregate(Expression::Operation::kMin, std::move(term)) {} -std::shared_ptr MinAggregate::Make(std::shared_ptr term) { - return std::shared_ptr(new MinAggregate(std::move(term))); +Result> MinAggregate::Make( + std::shared_ptr term) { + if (!term) { + return InvalidExpression("Bound min aggregate requires non-null term"); + } + if (!term->type()->is_primitive()) { + return InvalidExpression("Min aggregate term should be primitive"); + } + return std::unique_ptr(new MinAggregate(std::move(term))); } Result MinAggregate::Evaluate(const StructLike& data) const { return term()->Evaluate(data); } +Result MinAggregate::Evaluate(const DataFile& file) const { + auto field_id = GetFieldId(term()); + auto lower = GetMapValue(file.lower_bounds, field_id); + return EvaluateBoundTerm(*term(), lower); +} + std::unique_ptr MinAggregate::NewAggregator() const { return std::unique_ptr(new MinAggregator(*this)); } +bool MinAggregate::HasValue(const DataFile& file) const { + auto field_id = GetFieldId(term()); + bool has_bound = file.lower_bounds.contains(field_id); + auto value_count = GetMapValue(file.value_counts, field_id); + auto null_count = GetMapValue(file.null_value_counts, field_id); + bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() && + null_count.value() == value_count.value(); + return has_bound || all_null; +} + // -------------------- Unbound binding -------------------- template @@ -275,8 +508,10 @@ Result>> UnboundAggregateImpl::Make( } template class Aggregate>; +template class Aggregate>; template class Aggregate; template class 
UnboundAggregateImpl; +template class UnboundAggregateImpl; // -------------------- AggregateEvaluator -------------------- @@ -296,6 +531,13 @@ class AggregateEvaluatorImpl : public AggregateEvaluator { return {}; } + Status Update(const DataFile& file) override { + for (auto& aggregator : aggregators_) { + ICEBERG_RETURN_UNEXPECTED(aggregator->Update(file)); + } + return {}; + } + Result> GetResults() const override { results_.clear(); results_.reserve(aggregates_.size()); @@ -315,6 +557,10 @@ class AggregateEvaluatorImpl : public AggregateEvaluator { return all.front(); } + bool AllAggregatorsValid() const override { + return std::ranges::all_of(aggregators_, &BoundAggregate::Aggregator::IsValid); + } + private: std::vector> aggregates_; std::vector> aggregators_; diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h index cde9e4583..6cf659d6f 100644 --- a/src/iceberg/expression/aggregate.h +++ b/src/iceberg/expression/aggregate.h @@ -109,14 +109,15 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound virtual Status Update(const StructLike& data) = 0; - virtual Status Update(const DataFile& file) { - return NotImplemented("Update(DataFile) not implemented"); - } + virtual Status Update(const DataFile& file) = 0; + + /// \brief Whether the aggregator is still valid. + virtual bool IsValid() const = 0; /// \brief Get the result of the aggregation. /// \return The result of the aggregation. /// \note It is an undefined behavior to call this method if any previous Update call - /// has returned an error. + /// has returned an error or if IsValid() returns false. virtual Literal GetResult() const = 0; }; @@ -128,6 +129,11 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound Result Evaluate(const StructLike& data) const override = 0; + virtual Result Evaluate(const DataFile& file) const = 0; + + /// \brief Whether metrics in the data file are sufficient to evaluate. 
+ virtual bool HasValue(const DataFile& file) const = 0; + bool is_bound_aggregate() const override { return true; } /// \brief Create a new aggregator for this aggregate. @@ -142,12 +148,15 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound /// \brief Base class for COUNT aggregates. class ICEBERG_EXPORT CountAggregate : public BoundAggregate { public: - Result Evaluate(const StructLike& data) const final; + Result Evaluate(const StructLike& data) const override; + Result Evaluate(const DataFile& file) const override; std::unique_ptr NewAggregator() const override; /// \brief Count for a single row. Subclasses implement this. virtual Result CountFor(const StructLike& data) const = 0; + /// \brief Count using metrics from a data file. + virtual Result CountFor(const DataFile& file) const = 0; protected: CountAggregate(Expression::Operation op, std::shared_ptr term) @@ -161,6 +170,8 @@ class ICEBERG_EXPORT CountNonNullAggregate : public CountAggregate { std::shared_ptr term); Result CountFor(const StructLike& data) const override; + Result CountFor(const DataFile& file) const override; + bool HasValue(const DataFile& file) const override; private: explicit CountNonNullAggregate(std::shared_ptr term); @@ -173,6 +184,8 @@ class ICEBERG_EXPORT CountNullAggregate : public CountAggregate { std::shared_ptr term); Result CountFor(const StructLike& data) const override; + Result CountFor(const DataFile& file) const override; + bool HasValue(const DataFile& file) const override; private: explicit CountNullAggregate(std::shared_ptr term); @@ -184,6 +197,8 @@ class ICEBERG_EXPORT CountStarAggregate : public CountAggregate { static Result> Make(); Result CountFor(const StructLike& data) const override; + Result CountFor(const DataFile& file) const override; + bool HasValue(const DataFile& file) const override; private: CountStarAggregate(); @@ -192,9 +207,11 @@ class ICEBERG_EXPORT CountStarAggregate : public CountAggregate { /// \brief Bound MAX aggregate. 
class ICEBERG_EXPORT MaxAggregate : public BoundAggregate { public: - static std::shared_ptr Make(std::shared_ptr term); + static Result> Make(std::shared_ptr term); Result Evaluate(const StructLike& data) const override; + Result Evaluate(const DataFile& file) const override; + bool HasValue(const DataFile& file) const override; std::unique_ptr NewAggregator() const override; @@ -205,9 +222,11 @@ class ICEBERG_EXPORT MaxAggregate : public BoundAggregate { /// \brief Bound MIN aggregate. class ICEBERG_EXPORT MinAggregate : public BoundAggregate { public: - static std::shared_ptr Make(std::shared_ptr term); + static Result> Make(std::shared_ptr term); Result Evaluate(const StructLike& data) const override; + Result Evaluate(const DataFile& file) const override; + bool HasValue(const DataFile& file) const override; std::unique_ptr NewAggregator() const override; @@ -234,11 +253,17 @@ class ICEBERG_EXPORT AggregateEvaluator { /// \brief Update aggregates with a row. virtual Status Update(const StructLike& data) = 0; + /// \brief Update aggregates using data file metrics. + virtual Status Update(const DataFile& file) = 0; + /// \brief Final aggregated value. virtual Result> GetResults() const = 0; /// \brief Convenience accessor when only one aggregate is evaluated. virtual Result GetResult() const = 0; + + /// \brief Whether all aggregators are still valid (metrics present). 
+ virtual bool AllAggregatorsValid() const = 0; }; } // namespace iceberg diff --git a/src/iceberg/expression/expressions.cc b/src/iceberg/expression/expressions.cc index 7eef60232..4b0e538ae 100644 --- a/src/iceberg/expression/expressions.cc +++ b/src/iceberg/expression/expressions.cc @@ -138,6 +138,13 @@ std::shared_ptr> Expressions::Max( return agg; } +std::shared_ptr> Expressions::Max( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kMax, std::move(expr))); + return agg; +} + std::shared_ptr> Expressions::Min(std::string name) { return Min(Ref(std::move(name))); } @@ -149,6 +156,13 @@ std::shared_ptr> Expressions::Min( return agg; } +std::shared_ptr> Expressions::Min( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kMin, std::move(expr))); + return agg; +} + // Template implementations for unary predicates std::shared_ptr> Expressions::IsNull( diff --git a/src/iceberg/expression/expressions.h b/src/iceberg/expression/expressions.h index 92c523ca7..4ef2e7800 100644 --- a/src/iceberg/expression/expressions.h +++ b/src/iceberg/expression/expressions.h @@ -135,6 +135,9 @@ class ICEBERG_EXPORT Expressions { /// \brief Create a MAX aggregate for an unbound term. static std::shared_ptr> Max( std::shared_ptr> expr); + /// \brief Create a MAX aggregate for an unbound transform term. + static std::shared_ptr> Max( + std::shared_ptr> expr); /// \brief Create a MIN aggregate for a field name. static std::shared_ptr> Min(std::string name); @@ -142,6 +145,9 @@ class ICEBERG_EXPORT Expressions { /// \brief Create a MIN aggregate for an unbound term. static std::shared_ptr> Min( std::shared_ptr> expr); + /// \brief Create a MIN aggregate for an unbound transform term. 
+ static std::shared_ptr> Min( + std::shared_ptr> expr); // Unary predicates diff --git a/src/iceberg/expression/term.h b/src/iceberg/expression/term.h index 8b9606e56..616f11da6 100644 --- a/src/iceberg/expression/term.h +++ b/src/iceberg/expression/term.h @@ -37,10 +37,13 @@ namespace iceberg { /// \brief A term is an expression node that produces a typed value when evaluated. class ICEBERG_EXPORT Term : public util::Formattable { public: - enum class Kind : uint8_t { kReference = 0, kTransform, kExtract }; + enum class Kind : uint8_t { kReference, kTransform, kExtract }; /// \brief Returns the kind of this term. virtual Kind kind() const = 0; + + /// \brief Returns whether this term is unbound. + virtual bool is_unbound() const = 0; }; template @@ -53,6 +56,8 @@ template class ICEBERG_EXPORT UnboundTerm : public Unbound, public Term { public: using BoundType = B; + + bool is_unbound() const override { return true; } }; /// \brief Base class for bound terms. @@ -66,8 +71,6 @@ class ICEBERG_EXPORT BoundTerm : public Bound, public Term { /// \brief Returns whether this term may produce null values. virtual bool MayProduceNull() const = 0; - // TODO(gangwu): add a comparator function to Literal and BoundTerm. - /// \brief Returns whether this term is equivalent to another. /// /// Two terms are equivalent if they produce the same values when evaluated. @@ -79,6 +82,8 @@ class ICEBERG_EXPORT BoundTerm : public Bound, public Term { friend bool operator==(const BoundTerm& lhs, const BoundTerm& rhs) { return lhs.Equals(rhs); } + + bool is_unbound() const override { return false; } }; /// \brief A reference represents a named field in an expression. 
diff --git a/src/iceberg/row/struct_like.cc b/src/iceberg/row/struct_like.cc index b0fb67fb4..85bde1a69 100644 --- a/src/iceberg/row/struct_like.cc +++ b/src/iceberg/row/struct_like.cc @@ -19,7 +19,9 @@ #include "iceberg/row/struct_like.h" +#include #include +#include #include "iceberg/result.h" #include "iceberg/util/checked_cast.h" @@ -28,6 +30,44 @@ namespace iceberg { +Result LiteralToScalar(const Literal& literal) { + if (literal.IsNull()) { + return Scalar{std::monostate{}}; + } + + switch (literal.type()->type_id()) { + case TypeId::kBoolean: + return Scalar{std::get(literal.value())}; + case TypeId::kInt: + case TypeId::kDate: + return Scalar{std::get(literal.value())}; + case TypeId::kLong: + case TypeId::kTime: + case TypeId::kTimestamp: + case TypeId::kTimestampTz: + return Scalar{std::get(literal.value())}; + case TypeId::kFloat: + return Scalar{std::get(literal.value())}; + case TypeId::kDouble: + return Scalar{std::get(literal.value())}; + case TypeId::kString: { + const auto& str = std::get(literal.value()); + return Scalar{std::string_view(str)}; + } + case TypeId::kBinary: + case TypeId::kFixed: { + const auto& bytes = std::get>(literal.value()); + return Scalar{ + std::string_view(reinterpret_cast(bytes.data()), bytes.size())}; + } + case TypeId::kDecimal: + return Scalar{std::get(literal.value())}; + default: + return NotSupported("Cannot convert literal of type {} to Scalar", + literal.type()->ToString()); + } +} + StructLikeAccessor::StructLikeAccessor(std::shared_ptr type, std::span position_path) : type_(std::move(type)) { diff --git a/src/iceberg/row/struct_like.h b/src/iceberg/row/struct_like.h index 4999da69e..36ff5d86b 100644 --- a/src/iceberg/row/struct_like.h +++ b/src/iceberg/row/struct_like.h @@ -55,6 +55,9 @@ using Scalar = std::variant, // for list std::shared_ptr>; // for map +/// \brief Convert a Literal to a Scalar +Result LiteralToScalar(const Literal& literal); + /// \brief An immutable struct-like wrapper. 
class ICEBERG_EXPORT StructLike { public: diff --git a/src/iceberg/test/aggregate_test.cc b/src/iceberg/test/aggregate_test.cc index 264e606f7..9885c7a6f 100644 --- a/src/iceberg/test/aggregate_test.cc +++ b/src/iceberg/test/aggregate_test.cc @@ -23,6 +23,7 @@ #include "iceberg/expression/binder.h" #include "iceberg/expression/expressions.h" +#include "iceberg/manifest/manifest_entry.h" #include "iceberg/row/struct_like.h" #include "iceberg/schema.h" #include "iceberg/test/matchers.h" @@ -236,4 +237,243 @@ TEST(AggregateTest, MultipleAggregatesInEvaluator) { EXPECT_EQ(std::get(results[4].value()), 4); // count_star } +TEST(AggregateTest, AggregatesFromDataFileMetrics) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count_bound = BindAggregate(schema, Expressions::Count("id")); + auto count_null_bound = BindAggregate(schema, Expressions::CountNull("id")); + auto count_star_bound = BindAggregate(schema, Expressions::CountStar()); + auto max_bound = BindAggregate(schema, Expressions::Max("value")); + auto min_bound = BindAggregate(schema, Expressions::Min("value")); + + std::vector> aggregates{ + count_bound, count_null_bound, count_star_bound, max_bound, min_bound}; + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates)); + + ICEBERG_UNWRAP_OR_FAIL(auto lower, Literal::Int(5).Serialize()); + ICEBERG_UNWRAP_OR_FAIL(auto upper, Literal::Int(50).Serialize()); + DataFile file{ + .record_count = 10, + .value_counts = {{1, 10}, {2, 10}}, + .null_value_counts = {{1, 2}, {2, 0}}, + .lower_bounds = {{2, lower}}, + .upper_bounds = {{2, upper}}, + }; + + ASSERT_TRUE(evaluator->Update(file).has_value()); + + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults()); + ASSERT_EQ(results.size(), aggregates.size()); + EXPECT_EQ(std::get(results[0].value()), 8); // count(id) = 10 - 2 + EXPECT_EQ(std::get(results[1].value()), 2); // count_null(id) + 
EXPECT_EQ(std::get(results[2].value()), 10); // count_star + EXPECT_EQ(std::get(results[3].value()), 50); // max(value) + EXPECT_EQ(std::get(results[4].value()), 5); // min(value) +} + +TEST(AggregateTest, AggregatesFromDataFileMissingMetricsReturnNull) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count_bound = BindAggregate(schema, Expressions::Count("id")); + auto count_null_bound = BindAggregate(schema, Expressions::CountNull("id")); + auto count_star_bound = BindAggregate(schema, Expressions::CountStar()); + auto max_bound = BindAggregate(schema, Expressions::Max("value")); + auto min_bound = BindAggregate(schema, Expressions::Min("value")); + + std::vector> aggregates{ + count_bound, count_null_bound, count_star_bound, max_bound, min_bound}; + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates)); + + DataFile file{.record_count = -1}; // missing/invalid + + ASSERT_TRUE(evaluator->Update(file).has_value()); + + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults()); + ASSERT_EQ(results.size(), aggregates.size()); + for (const auto& literal : results) { + EXPECT_TRUE(literal.IsNull()); + } +} + +TEST(AggregateTest, AggregatesFromDataFileWithTransform) { + Schema schema({SchemaField::MakeOptional(1, "id", int32())}); + + auto truncate_id = Expressions::Truncate("id", 10); + auto max_bound = BindAggregate(schema, Expressions::Max(truncate_id)); + auto min_bound = BindAggregate(schema, Expressions::Min(truncate_id)); + + std::vector> aggregates{max_bound, min_bound}; + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates)); + + ICEBERG_UNWRAP_OR_FAIL(auto lower, Literal::Int(5).Serialize()); + ICEBERG_UNWRAP_OR_FAIL(auto upper, Literal::Int(23).Serialize()); + DataFile file{ + .record_count = 5, + .value_counts = {{1, 5}}, + .null_value_counts = {{1, 0}}, + .lower_bounds = {{1, lower}}, + .upper_bounds = {{1, upper}}, + }; + + 
ASSERT_TRUE(evaluator->Update(file).has_value()); + + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults()); + ASSERT_EQ(results.size(), aggregates.size()); + // Truncate width 10: max(truncate(23)) -> 20, min(truncate(5)) -> 0 + EXPECT_EQ(std::get(results[0].value()), 20); + EXPECT_EQ(std::get(results[1].value()), 0); + EXPECT_TRUE(evaluator->AllAggregatorsValid()); +} + +TEST(AggregateTest, DataFileAggregatorParity) { + Schema schema({SchemaField::MakeRequired(1, "id", int32()), + SchemaField::MakeOptional(2, "no_stats", int32()), + SchemaField::MakeOptional(3, "all_nulls", string()), + SchemaField::MakeOptional(4, "some_nulls", string())}); + + auto make_bounds = [](int field_id, int32_t lower, int32_t upper) { + std::map> lower_bounds; + std::map> upper_bounds; + auto lser = Literal::Int(lower).Serialize().value(); + auto user = Literal::Int(upper).Serialize().value(); + lower_bounds.emplace(field_id, std::move(lser)); + upper_bounds.emplace(field_id, std::move(user)); + return std::pair{std::move(lower_bounds), std::move(upper_bounds)}; + }; + + auto [b1_lower, b1_upper] = make_bounds(1, 33, 2345); + DataFile file{ + .file_path = "file.avro", + .record_count = 50, + .value_counts = {{1, 50}, {3, 50}, {4, 50}}, + .null_value_counts = {{1, 10}, {3, 50}, {4, 10}}, + .lower_bounds = std::move(b1_lower), + .upper_bounds = std::move(b1_upper), + }; + + auto [b2_lower, b2_upper] = make_bounds(1, 33, 100); + DataFile missing_some_nulls_1{ + .file_path = "file_2.avro", + .record_count = 20, + .value_counts = {{1, 20}, {3, 20}}, + .null_value_counts = {{1, 0}, {3, 20}}, + .lower_bounds = std::move(b2_lower), + .upper_bounds = std::move(b2_upper), + }; + + auto [b3_lower, b3_upper] = make_bounds(1, -33, 3333); + DataFile missing_some_nulls_2{ + .file_path = "file_3.avro", + .record_count = 20, + .value_counts = {{1, 20}, {3, 20}}, + .null_value_counts = {{1, 20}, {3, 20}}, + .lower_bounds = std::move(b3_lower), + .upper_bounds = std::move(b3_upper), + }; + + 
DataFile missing_some_stats{ + .file_path = "file_missing_stats.avro", + .record_count = 20, + .value_counts = {{1, 20}, {4, 10}}, + }; + auto [b4_lower, b4_upper] = make_bounds(1, -3, 1333); + missing_some_stats.lower_bounds = std::move(b4_lower); + missing_some_stats.upper_bounds = std::move(b4_upper); + + DataFile missing_all_optional_stats{ + .file_path = "file_null_stats.avro", + .record_count = 20, + }; + + auto run_case = [&](const std::vector>& exprs, + const std::vector& files, + const std::vector>& expected, + bool expect_all_valid) { + std::vector> aggregates; + aggregates.reserve(exprs.size()); + for (const auto& e : exprs) { + aggregates.emplace_back(BindAggregate(schema, e)); + } + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates)); + for (const auto& f : files) { + ASSERT_TRUE(evaluator->Update(f).has_value()); + } + ASSERT_EQ(evaluator->AllAggregatorsValid(), expect_all_valid); + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults()); + ASSERT_EQ(results.size(), expected.size()); + for (size_t i = 0; i < expected.size(); ++i) { + if (!expected[i].has_value()) { + EXPECT_TRUE(results[i].IsNull()); + } else { + const auto& exp = *expected[i]; + const auto& res = results[i].value(); + if (std::holds_alternative(exp)) { + EXPECT_EQ(std::get(res), std::get(exp)); + } else if (std::holds_alternative(exp)) { + EXPECT_EQ(std::get(res), std::get(exp)); + } else { + FAIL() << "Unexpected expected type"; + } + } + } + }; + + // testIntAggregate + run_case({Expressions::CountStar(), Expressions::Count("id"), + Expressions::CountNull("id"), Expressions::Max("id"), Expressions::Min("id")}, + {file, missing_some_nulls_1, missing_some_nulls_2}, + {Scalar{int64_t{90}}, Scalar{int64_t{60}}, Scalar{int64_t{30}}, + Scalar{int32_t{3333}}, Scalar{int32_t{-33}}}, + /*expect_all_valid=*/true); + + // testAllNulls + run_case({Expressions::CountStar(), Expressions::Count("all_nulls"), + Expressions::CountNull("all_nulls"), 
Expressions::Max("all_nulls"), + Expressions::Min("all_nulls")}, + {file, missing_some_nulls_1, missing_some_nulls_2}, + {Scalar{int64_t{90}}, Scalar{int64_t{0}}, Scalar{int64_t{90}}, std::nullopt, + std::nullopt}, + /*expect_all_valid=*/true); + + // testSomeNulls -> missing null counts for field 4 + run_case({Expressions::CountStar(), Expressions::Count("some_nulls"), + Expressions::CountNull("some_nulls"), Expressions::Max("some_nulls"), + Expressions::Min("some_nulls")}, + {file, missing_some_nulls_1, missing_some_nulls_2}, + {Scalar{int64_t{90}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + /*expect_all_valid=*/false); + + // testNoStats -> field 2 has no metrics + run_case({Expressions::CountStar(), Expressions::Count("no_stats"), + Expressions::CountNull("no_stats"), Expressions::Max("no_stats"), + Expressions::Min("no_stats")}, + {file, missing_some_nulls_1, missing_some_nulls_2}, + {Scalar{int64_t{90}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + /*expect_all_valid=*/false); + + // testIntAggregateAllMissingStats -> id missing optional stats + run_case({Expressions::CountStar(), Expressions::Count("id"), + Expressions::CountNull("id"), Expressions::Max("id"), Expressions::Min("id")}, + {missing_all_optional_stats}, + {Scalar{int64_t{20}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + /*expect_all_valid=*/false); + + // testOptionalColAllMissingStats -> field 2 missing everything + run_case({Expressions::CountStar(), Expressions::Count("no_stats"), + Expressions::CountNull("no_stats"), Expressions::Max("no_stats"), + Expressions::Min("no_stats")}, + {missing_all_optional_stats}, + {Scalar{int64_t{20}}, std::nullopt, std::nullopt, std::nullopt, std::nullopt}, + /*expect_all_valid=*/false); + + // testMissingSomeStats -> some_nulls missing null stats entirely + run_case({Expressions::CountStar(), Expressions::Count("some_nulls"), + Expressions::Max("some_nulls"), Expressions::Min("some_nulls")}, + 
{missing_some_stats}, + {Scalar{int64_t{20}}, std::nullopt, std::nullopt, std::nullopt}, + /*expect_all_valid=*/false); +} + } // namespace iceberg