1919
2020#include " iceberg/expression/aggregate.h"
2121
22+ #include < algorithm>
2223#include < format>
24+ #include < map>
2325#include < optional>
26+ #include < string_view>
2427#include < vector>
2528
2629#include " iceberg/expression/literal.h"
30+ #include " iceberg/manifest/manifest_entry.h"
2731#include " iceberg/row/struct_like.h"
2832#include " iceberg/type.h"
2933#include " iceberg/util/checked_cast.h"
@@ -38,6 +42,19 @@ std::shared_ptr<PrimitiveType> GetPrimitiveType(const BoundTerm& term) {
3842 return internal::checked_pointer_cast<PrimitiveType>(term.type ());
3943}
4044
45+ Result<Literal> EvaluateBoundTerm (const BoundTerm& term,
46+ const std::optional<std::vector<uint8_t >>& bound) {
47+ auto ptype = GetPrimitiveType (term);
48+ if (!bound.has_value ()) {
49+ SingleValueStructLike data (Literal::Null (ptype));
50+ return term.Evaluate (data);
51+ }
52+
53+ ICEBERG_ASSIGN_OR_RAISE (auto literal, Literal::Deserialize (*bound, ptype));
54+ SingleValueStructLike data (std::move (literal));
55+ return term.Evaluate (data);
56+ }
57+
4158class CountAggregator : public BoundAggregate ::Aggregator {
4259 public:
4360 explicit CountAggregator (const CountAggregate& aggregate) : aggregate_(aggregate) {}
@@ -48,11 +65,32 @@ class CountAggregator : public BoundAggregate::Aggregator {
4865 return {};
4966 }
5067
51- Literal GetResult () const override { return Literal::Long (count_); }
68+ Status Update (const DataFile& file) override {
69+ if (!valid_) {
70+ return {};
71+ }
72+ if (!aggregate_.HasValue (file)) {
73+ valid_ = false ;
74+ return {};
75+ }
76+ ICEBERG_ASSIGN_OR_RAISE (auto count, aggregate_.CountFor (file));
77+ count_ += count;
78+ return {};
79+ }
80+
81+ Literal GetResult () const override {
82+ if (!valid_) {
83+ return Literal::Null (int64 ());
84+ }
85+ return Literal::Long (count_);
86+ }
87+
88+ bool IsValid () const override { return valid_; }
5289
5390 private:
5491 const CountAggregate& aggregate_;
5592 int64_t count_ = 0 ;
93+ bool valid_ = true ;
5694};
5795
5896class MaxAggregator : public BoundAggregate ::Aggregator {
@@ -82,11 +120,47 @@ class MaxAggregator : public BoundAggregate::Aggregator {
82120 return {};
83121 }
84122
85- Literal GetResult () const override { return current_; }
123+ Status Update (const DataFile& file) override {
124+ if (!valid_) {
125+ return {};
126+ }
127+ if (!aggregate_.HasValue (file)) {
128+ valid_ = false ;
129+ return {};
130+ }
131+
132+ ICEBERG_ASSIGN_OR_RAISE (auto value, aggregate_.Evaluate (file));
133+ if (value.IsNull ()) {
134+ return {};
135+ }
136+ if (current_.IsNull ()) {
137+ current_ = std::move (value);
138+ return {};
139+ }
140+
141+ if (auto ordering = value <=> current_;
142+ ordering == std::partial_ordering::unordered) {
143+ return InvalidArgument (" Cannot compare literal {} with current value {}" ,
144+ value.ToString (), current_.ToString ());
145+ } else if (ordering == std::partial_ordering::greater) {
146+ current_ = std::move (value);
147+ }
148+ return {};
149+ }
150+
151+ Literal GetResult () const override {
152+ if (!valid_) {
153+ return Literal::Null (GetPrimitiveType (*aggregate_.term ()));
154+ }
155+ return current_;
156+ }
157+
158+ bool IsValid () const override { return valid_; }
86159
87160 private:
88161 const MaxAggregate& aggregate_;
89162 Literal current_;
163+ bool valid_ = true ;
90164};
91165
92166class MinAggregator : public BoundAggregate ::Aggregator {
@@ -115,13 +189,65 @@ class MinAggregator : public BoundAggregate::Aggregator {
115189 return {};
116190 }
117191
118- Literal GetResult () const override { return current_; }
192+ Status Update (const DataFile& file) override {
193+ if (!valid_) {
194+ return {};
195+ }
196+ if (!aggregate_.HasValue (file)) {
197+ valid_ = false ;
198+ return {};
199+ }
200+
201+ ICEBERG_ASSIGN_OR_RAISE (auto value, aggregate_.Evaluate (file));
202+ if (value.IsNull ()) {
203+ return {};
204+ }
205+ if (current_.IsNull ()) {
206+ current_ = std::move (value);
207+ return {};
208+ }
209+
210+ if (auto ordering = value <=> current_;
211+ ordering == std::partial_ordering::unordered) {
212+ return InvalidArgument (" Cannot compare literal {} with current value {}" ,
213+ value.ToString (), current_.ToString ());
214+ } else if (ordering == std::partial_ordering::less) {
215+ current_ = std::move (value);
216+ }
217+ return {};
218+ }
219+
220+ Literal GetResult () const override {
221+ if (!valid_) {
222+ return Literal::Null (GetPrimitiveType (*aggregate_.term ()));
223+ }
224+ return current_;
225+ }
226+
227+ bool IsValid () const override { return valid_; }
119228
120229 private:
121230 const MinAggregate& aggregate_;
122231 Literal current_;
232+ bool valid_ = true ;
123233};
124234
235+ template <typename T>
236+ std::optional<T> GetMapValue (const std::map<int32_t , T>& map, int32_t key) {
237+ auto iter = map.find (key);
238+ if (iter == map.end ()) {
239+ return std::nullopt ;
240+ }
241+ return iter->second ;
242+ }
243+
244+ int32_t GetFieldId (const std::shared_ptr<BoundTerm>& term) {
245+ ICEBERG_DCHECK (term != nullptr , " Aggregate term should not be null" );
246+ auto ref = term->reference ();
247+ ICEBERG_DCHECK (ref != nullptr , " Aggregate term reference should not be null" );
248+ return ref->field ().field_id ();
249+ }
250+
125251} // namespace
126252
127253template <TermType T>
@@ -149,7 +275,11 @@ std::string Aggregate<T>::ToString() const {
149275// -------------------- CountAggregate --------------------
150276
151277Result<Literal> CountAggregate::Evaluate (const StructLike& data) const {
152- return CountFor (data).transform ([](int64_t count) { return Literal::Long (count); });
278+ return CountFor (data).transform (Literal::Long);
279+ }
280+
281+ Result<Literal> CountAggregate::Evaluate (const DataFile& file) const {
282+ return CountFor (file).transform (Literal::Long);
153283}
154284
155285std::unique_ptr<BoundAggregate::Aggregator> CountAggregate::NewAggregator () const {
@@ -173,6 +303,22 @@ Result<int64_t> CountNonNullAggregate::CountFor(const StructLike& data) const {
173303 [](const auto & val) { return val.IsNull () ? 0 : 1 ; });
174304}
175305
306+ Result<int64_t > CountNonNullAggregate::CountFor (const DataFile& file) const {
307+ auto field_id = GetFieldId (term ());
308+ if (!HasValue (file)) {
309+ return NotFound (" Missing metrics for field id {}" , field_id);
310+ }
311+ auto value_count = GetMapValue (file.value_counts , field_id).value ();
312+ auto null_count = GetMapValue (file.null_value_counts , field_id).value ();
313+ return value_count - null_count;
314+ }
315+
316+ bool CountNonNullAggregate::HasValue (const DataFile& file) const {
317+ auto field_id = GetFieldId (term ());
318+ return file.value_counts .contains (field_id) &&
319+ file.null_value_counts .contains (field_id);
320+ }
321+
176322CountNullAggregate::CountNullAggregate (std::shared_ptr<BoundTerm> term)
177323 : CountAggregate(Expression::Operation::kCountNull , std::move(term)) {}
178324
@@ -189,6 +335,18 @@ Result<int64_t> CountNullAggregate::CountFor(const StructLike& data) const {
189335 [](const auto & val) { return val.IsNull () ? 1 : 0 ; });
190336}
191337
338+ Result<int64_t > CountNullAggregate::CountFor (const DataFile& file) const {
339+ auto field_id = GetFieldId (term ());
340+ if (!HasValue (file)) {
341+ return NotFound (" Missing metrics for field id {}" , field_id);
342+ }
343+ return GetMapValue (file.null_value_counts , field_id).value ();
344+ }
345+
346+ bool CountNullAggregate::HasValue (const DataFile& file) const {
347+ return file.null_value_counts .contains (GetFieldId (term ()));
348+ }
349+
192350CountStarAggregate::CountStarAggregate ()
193351 : CountAggregate(Expression::Operation::kCountStar , nullptr ) {}
194352
@@ -200,6 +358,17 @@ Result<int64_t> CountStarAggregate::CountFor(const StructLike& /*data*/) const {
200358 return 1 ;
201359}
202360
361+ Result<int64_t > CountStarAggregate::CountFor (const DataFile& file) const {
362+ if (!HasValue (file)) {
363+ return NotFound (" Record count is missing" );
364+ }
365+ return file.record_count ;
366+ }
367+
368+ bool CountStarAggregate::HasValue (const DataFile& file) const {
369+ return file.record_count >= 0 ;
370+ }
371+
203372MaxAggregate::MaxAggregate (std::shared_ptr<BoundTerm> term)
204373 : BoundAggregate(Expression::Operation::kMax , std::move(term)) {}
205374
@@ -211,10 +380,26 @@ Result<Literal> MaxAggregate::Evaluate(const StructLike& data) const {
211380 return term ()->Evaluate (data);
212381}
213382
383+ Result<Literal> MaxAggregate::Evaluate (const DataFile& file) const {
384+ auto field_id = GetFieldId (term ());
385+ auto upper = GetMapValue (file.upper_bounds , field_id);
386+ return EvaluateBoundTerm (*term (), upper);
387+ }
388+
214389std::unique_ptr<BoundAggregate::Aggregator> MaxAggregate::NewAggregator () const {
215390 return std::unique_ptr<BoundAggregate::Aggregator>(new MaxAggregator (*this ));
216391}
217392
393+ bool MaxAggregate::HasValue (const DataFile& file) const {
394+ auto field_id = GetFieldId (term ());
395+ bool has_bound = file.upper_bounds .contains (field_id);
396+ auto value_count = GetMapValue (file.value_counts , field_id);
397+ auto null_count = GetMapValue (file.null_value_counts , field_id);
398+ bool all_null = value_count.has_value () && *value_count > 0 && null_count.has_value () &&
399+ null_count.value () == value_count.value ();
400+ return has_bound || all_null;
401+ }
402+
218403MinAggregate::MinAggregate (std::shared_ptr<BoundTerm> term)
219404 : BoundAggregate(Expression::Operation::kMin , std::move(term)) {}
220405
@@ -226,10 +411,26 @@ Result<Literal> MinAggregate::Evaluate(const StructLike& data) const {
226411 return term ()->Evaluate (data);
227412}
228413
414+ Result<Literal> MinAggregate::Evaluate (const DataFile& file) const {
415+ auto field_id = GetFieldId (term ());
416+ auto lower = GetMapValue (file.lower_bounds , field_id);
417+ return EvaluateBoundTerm (*term (), lower);
418+ }
419+
229420std::unique_ptr<BoundAggregate::Aggregator> MinAggregate::NewAggregator () const {
230421 return std::unique_ptr<BoundAggregate::Aggregator>(new MinAggregator (*this ));
231422}
232423
424+ bool MinAggregate::HasValue (const DataFile& file) const {
425+ auto field_id = GetFieldId (term ());
426+ bool has_bound = file.lower_bounds .contains (field_id);
427+ auto value_count = GetMapValue (file.value_counts , field_id);
428+ auto null_count = GetMapValue (file.null_value_counts , field_id);
429+ bool all_null = value_count.has_value () && *value_count > 0 && null_count.has_value () &&
430+ null_count.value () == value_count.value ();
431+ return has_bound || all_null;
432+ }
433+
233434// -------------------- Unbound binding --------------------
234435
235436template <typename B>
@@ -275,8 +476,10 @@ Result<std::shared_ptr<UnboundAggregateImpl<B>>> UnboundAggregateImpl<B>::Make(
275476}
276477
277478template class Aggregate <UnboundTerm<BoundReference>>;
479+ template class Aggregate <UnboundTerm<BoundTransform>>;
278480template class Aggregate <BoundTerm>;
279481template class UnboundAggregateImpl <BoundReference>;
482+ template class UnboundAggregateImpl <BoundTransform>;
280483
281484// -------------------- AggregateEvaluator --------------------
282485
@@ -296,6 +499,13 @@ class AggregateEvaluatorImpl : public AggregateEvaluator {
296499 return {};
297500 }
298501
502+ Status Update (const DataFile& file) override {
503+ for (auto & aggregator : aggregators_) {
504+ ICEBERG_RETURN_UNEXPECTED (aggregator->Update (file));
505+ }
506+ return {};
507+ }
508+
299509 Result<std::span<const Literal>> GetResults () const override {
300510 results_.clear ();
301511 results_.reserve (aggregates_.size ());
@@ -315,6 +525,10 @@ class AggregateEvaluatorImpl : public AggregateEvaluator {
315525 return all.front ();
316526 }
317527
528+ bool AllAggregatorsValid () const override {
529+ return std::ranges::all_of (aggregators_, &BoundAggregate::Aggregator::IsValid);
530+ }
531+
318532 private:
319533 std::vector<std::shared_ptr<BoundAggregate>> aggregates_;
320534 std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators_;
0 commit comments