Skip to content

Commit 282b708

Browse files
committed
feat: add DataFile aggregate evaluation
1 parent 3fc0445 commit 282b708

File tree

7 files changed

+576
-7
lines changed

7 files changed

+576
-7
lines changed

src/iceberg/expression/aggregate.cc

Lines changed: 218 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,15 @@
1919

2020
#include "iceberg/expression/aggregate.h"
2121

22+
#include <algorithm>
2223
#include <format>
24+
#include <map>
2325
#include <optional>
26+
#include <string_view>
2427
#include <vector>
2528

2629
#include "iceberg/expression/literal.h"
30+
#include "iceberg/manifest/manifest_entry.h"
2731
#include "iceberg/row/struct_like.h"
2832
#include "iceberg/type.h"
2933
#include "iceberg/util/checked_cast.h"
@@ -38,6 +42,19 @@ std::shared_ptr<PrimitiveType> GetPrimitiveType(const BoundTerm& term) {
3842
return internal::checked_pointer_cast<PrimitiveType>(term.type());
3943
}
4044

45+
// Evaluates `term` against a single-value row built from an optional serialized bound.
//
// When `bound` holds bytes, they are deserialized into a literal of the term's
// primitive type (a deserialization error is returned to the caller). When `bound`
// is empty, the term is evaluated against a null literal of that type instead.
Result<Literal> EvaluateBoundTerm(const BoundTerm& term,
                                  const std::optional<std::vector<uint8_t>>& bound) {
  auto primitive_type = GetPrimitiveType(term);

  if (bound.has_value()) {
    ICEBERG_ASSIGN_OR_RAISE(auto decoded, Literal::Deserialize(*bound, primitive_type));
    SingleValueStructLike row(std::move(decoded));
    return term.Evaluate(row);
  }

  // No bound recorded: evaluate the term over a typed null.
  SingleValueStructLike row(Literal::Null(primitive_type));
  return term.Evaluate(row);
}
57+
4158
class CountAggregator : public BoundAggregate::Aggregator {
4259
public:
4360
explicit CountAggregator(const CountAggregate& aggregate) : aggregate_(aggregate) {}
@@ -48,11 +65,32 @@ class CountAggregator : public BoundAggregate::Aggregator {
4865
return {};
4966
}
5067

51-
Literal GetResult() const override { return Literal::Long(count_); }
68+
// Adds the count derived from one data file's metrics to the running total.
//
// Once the aggregator has been marked invalid, further files are ignored. A file
// for which the aggregate reports no value marks the aggregator invalid. Errors
// from CountFor are returned to the caller.
Status Update(const DataFile& file) override {
  if (valid_) {
    if (aggregate_.HasValue(file)) {
      ICEBERG_ASSIGN_OR_RAISE(auto file_count, aggregate_.CountFor(file));
      count_ += file_count;
    } else {
      valid_ = false;
    }
  }
  return {};
}
80+
81+
// Returns the accumulated count as a long literal, or a null int64 literal when
// the aggregator has been marked invalid.
Literal GetResult() const override {
  return valid_ ? Literal::Long(count_) : Literal::Null(int64());
}
87+
88+
// Reports whether every data file seen so far provided a value for this count.
bool IsValid() const override {
  return valid_;
}
5289

5390
private:
5491
const CountAggregate& aggregate_;
5592
int64_t count_ = 0;
93+
bool valid_ = true;
5694
};
5795

5896
class MaxAggregator : public BoundAggregate::Aggregator {
@@ -82,11 +120,47 @@ class MaxAggregator : public BoundAggregate::Aggregator {
82120
return {};
83121
}
84122

85-
Literal GetResult() const override { return current_; }
123+
// Folds one data file's evaluated value into the running maximum.
//
// Files are ignored once the aggregator is invalid; a file for which the aggregate
// reports no value marks it invalid. A null evaluated value leaves the current
// maximum untouched. Returns InvalidArgument when the new value and the current
// maximum compare as unordered; evaluation errors are returned to the caller.
Status Update(const DataFile& file) override {
  if (!valid_) {
    return {};
  }
  valid_ = aggregate_.HasValue(file);
  if (!valid_) {
    return {};
  }

  ICEBERG_ASSIGN_OR_RAISE(auto candidate, aggregate_.Evaluate(file));
  if (candidate.IsNull()) {
    return {};
  }

  // With no maximum recorded yet the candidate is adopted unconditionally;
  // otherwise it must compare strictly greater.
  if (!current_.IsNull()) {
    const auto ordering = candidate <=> current_;
    if (ordering == std::partial_ordering::unordered) {
      return InvalidArgument("Cannot compare literal {} with current value {}",
                             candidate.ToString(), current_.ToString());
    }
    if (ordering != std::partial_ordering::greater) {
      return {};
    }
  }

  current_ = std::move(candidate);
  return {};
}
150+
151+
// Returns the current maximum, or a null literal of the term's primitive type
// when the aggregator has been marked invalid.
Literal GetResult() const override {
  if (valid_) {
    return current_;
  }
  return Literal::Null(GetPrimitiveType(*aggregate_.term()));
}
157+
158+
// Reports whether every data file seen so far provided a value for this maximum.
bool IsValid() const override {
  return valid_;
}
86159

87160
private:
88161
const MaxAggregate& aggregate_;
89162
Literal current_;
163+
bool valid_ = true;
90164
};
91165

92166
class MinAggregator : public BoundAggregate::Aggregator {
@@ -115,13 +189,65 @@ class MinAggregator : public BoundAggregate::Aggregator {
115189
return {};
116190
}
117191

118-
Literal GetResult() const override { return current_; }
192+
// Folds one data file's evaluated value into the running minimum.
//
// Files are ignored once the aggregator is invalid; a file for which the aggregate
// reports no value marks it invalid. A null evaluated value leaves the current
// minimum untouched. Returns InvalidArgument when the new value and the current
// minimum compare as unordered; evaluation errors are returned to the caller.
Status Update(const DataFile& file) override {
  if (!valid_) {
    return {};
  }
  valid_ = aggregate_.HasValue(file);
  if (!valid_) {
    return {};
  }

  ICEBERG_ASSIGN_OR_RAISE(auto candidate, aggregate_.Evaluate(file));
  if (candidate.IsNull()) {
    return {};
  }

  // With no minimum recorded yet the candidate is adopted unconditionally;
  // otherwise it must compare strictly less.
  if (!current_.IsNull()) {
    const auto ordering = candidate <=> current_;
    if (ordering == std::partial_ordering::unordered) {
      return InvalidArgument("Cannot compare literal {} with current value {}",
                             candidate.ToString(), current_.ToString());
    }
    if (ordering != std::partial_ordering::less) {
      return {};
    }
  }

  current_ = std::move(candidate);
  return {};
}
219+
220+
// Returns the current minimum, or a null literal of the term's primitive type
// when the aggregator has been marked invalid.
Literal GetResult() const override {
  if (valid_) {
    return current_;
  }
  return Literal::Null(GetPrimitiveType(*aggregate_.term()));
}
226+
227+
// Reports whether every data file seen so far provided a value for this minimum.
bool IsValid() const override {
  return valid_;
}
119228

120229
private:
121230
const MinAggregate& aggregate_;
122231
Literal current_;
232+
bool valid_ = true;
123233
};
124234

235+
// Looks up `key` in `map` and returns a copy of the mapped value, or
// std::nullopt when the key is absent.
template <typename T>
std::optional<T> GetMapValue(const std::map<int32_t, T>& map, int32_t key) {
  if (const auto entry = map.find(key); entry != map.end()) {
    return entry->second;
  }
  return std::nullopt;
}
243+
244+
int32_t GetFieldId(const std::shared_ptr<BoundTerm>& term) {
245+
ICEBERG_DCHECK(term != nullptr, "Aggregate term should not be null");
246+
auto ref = term->reference();
247+
ICEBERG_DCHECK(ref != nullptr, "Aggregate term reference should not be null");
248+
return ref->field().field_id();
249+
}
250+
125251
} // namespace
126252

127253
template <TermType T>
@@ -149,7 +275,11 @@ std::string Aggregate<T>::ToString() const {
149275
// -------------------- CountAggregate --------------------
150276

151277
// Evaluates the count contribution of a single row and wraps it in a long literal.
// An error from CountFor is passed through unchanged.
Result<Literal> CountAggregate::Evaluate(const StructLike& data) const {
  auto row_count = CountFor(data);
  return row_count.transform([](int64_t count) { return Literal::Long(count); });
}
280+
281+
// Evaluates the count derived from a data file's metrics and wraps it in a long
// literal. An error from CountFor is passed through unchanged.
Result<Literal> CountAggregate::Evaluate(const DataFile& file) const {
  auto file_count = CountFor(file);
  return file_count.transform([](int64_t count) { return Literal::Long(count); });
}
154284

155285
std::unique_ptr<BoundAggregate::Aggregator> CountAggregate::NewAggregator() const {
@@ -173,6 +303,22 @@ Result<int64_t> CountNonNullAggregate::CountFor(const StructLike& data) const {
173303
[](const auto& val) { return val.IsNull() ? 0 : 1; });
174304
}
175305

306+
// Computes the number of non-null values of the aggregated field in `file` as
// value count minus null value count. Returns NotFound when HasValue reports that
// the required metrics are missing.
Result<int64_t> CountNonNullAggregate::CountFor(const DataFile& file) const {
  auto field_id = GetFieldId(term());
  if (!HasValue(file)) {
    return NotFound("Missing metrics for field id {}", field_id);
  }

  // HasValue has confirmed both entries exist, so value() cannot throw here.
  const auto total_values = GetMapValue(file.value_counts, field_id);
  const auto null_values = GetMapValue(file.null_value_counts, field_id);
  return total_values.value() - null_values.value();
}
315+
316+
bool CountNonNullAggregate::HasValue(const DataFile& file) const {
317+
auto field_id = GetFieldId(term());
318+
return file.value_counts.contains(field_id) &&
319+
file.null_value_counts.contains(field_id);
320+
}
321+
176322
CountNullAggregate::CountNullAggregate(std::shared_ptr<BoundTerm> term)
177323
: CountAggregate(Expression::Operation::kCountNull, std::move(term)) {}
178324

@@ -189,6 +335,18 @@ Result<int64_t> CountNullAggregate::CountFor(const StructLike& data) const {
189335
[](const auto& val) { return val.IsNull() ? 1 : 0; });
190336
}
191337

338+
// Returns the null value count recorded in `file` for the aggregated field, or
// NotFound when HasValue reports that the metric is missing.
Result<int64_t> CountNullAggregate::CountFor(const DataFile& file) const {
  auto field_id = GetFieldId(term());
  if (HasValue(file)) {
    // HasValue has confirmed the entry exists, so value() cannot throw here.
    const auto null_values = GetMapValue(file.null_value_counts, field_id);
    return null_values.value();
  }
  return NotFound("Missing metrics for field id {}", field_id);
}
345+
346+
bool CountNullAggregate::HasValue(const DataFile& file) const {
347+
return file.null_value_counts.contains(GetFieldId(term()));
348+
}
349+
192350
CountStarAggregate::CountStarAggregate()
193351
: CountAggregate(Expression::Operation::kCountStar, nullptr) {}
194352

@@ -200,6 +358,17 @@ Result<int64_t> CountStarAggregate::CountFor(const StructLike& /*data*/) const {
200358
return 1;
201359
}
202360

361+
// Returns the record count of `file`, or NotFound when HasValue rejects it.
Result<int64_t> CountStarAggregate::CountFor(const DataFile& file) const {
  if (HasValue(file)) {
    return file.record_count;
  }
  return NotFound("Record count is missing");
}
367+
368+
bool CountStarAggregate::HasValue(const DataFile& file) const {
369+
return file.record_count >= 0;
370+
}
371+
203372
MaxAggregate::MaxAggregate(std::shared_ptr<BoundTerm> term)
204373
: BoundAggregate(Expression::Operation::kMax, std::move(term)) {}
205374

@@ -211,10 +380,26 @@ Result<Literal> MaxAggregate::Evaluate(const StructLike& data) const {
211380
return term()->Evaluate(data);
212381
}
213382

383+
// Evaluates the aggregate term against the upper bound recorded in `file` for the
// aggregated field. A missing bound is forwarded as an empty optional.
Result<Literal> MaxAggregate::Evaluate(const DataFile& file) const {
  const auto upper_bound = GetMapValue(file.upper_bounds, GetFieldId(term()));
  return EvaluateBoundTerm(*term(), upper_bound);
}
388+
214389
std::unique_ptr<BoundAggregate::Aggregator> MaxAggregate::NewAggregator() const {
215390
return std::unique_ptr<BoundAggregate::Aggregator>(new MaxAggregator(*this));
216391
}
217392

393+
bool MaxAggregate::HasValue(const DataFile& file) const {
394+
auto field_id = GetFieldId(term());
395+
bool has_bound = file.upper_bounds.contains(field_id);
396+
auto value_count = GetMapValue(file.value_counts, field_id);
397+
auto null_count = GetMapValue(file.null_value_counts, field_id);
398+
bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() &&
399+
null_count.value() == value_count.value();
400+
return has_bound || all_null;
401+
}
402+
218403
MinAggregate::MinAggregate(std::shared_ptr<BoundTerm> term)
219404
: BoundAggregate(Expression::Operation::kMin, std::move(term)) {}
220405

@@ -226,10 +411,26 @@ Result<Literal> MinAggregate::Evaluate(const StructLike& data) const {
226411
return term()->Evaluate(data);
227412
}
228413

414+
// Evaluates the aggregate term against the lower bound recorded in `file` for the
// aggregated field. A missing bound is forwarded as an empty optional.
Result<Literal> MinAggregate::Evaluate(const DataFile& file) const {
  const auto lower_bound = GetMapValue(file.lower_bounds, GetFieldId(term()));
  return EvaluateBoundTerm(*term(), lower_bound);
}
419+
229420
std::unique_ptr<BoundAggregate::Aggregator> MinAggregate::NewAggregator() const {
230421
return std::unique_ptr<BoundAggregate::Aggregator>(new MinAggregator(*this));
231422
}
232423

424+
bool MinAggregate::HasValue(const DataFile& file) const {
425+
auto field_id = GetFieldId(term());
426+
bool has_bound = file.lower_bounds.contains(field_id);
427+
auto value_count = GetMapValue(file.value_counts, field_id);
428+
auto null_count = GetMapValue(file.null_value_counts, field_id);
429+
bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() &&
430+
null_count.value() == value_count.value();
431+
return has_bound || all_null;
432+
}
433+
233434
// -------------------- Unbound binding --------------------
234435

235436
template <typename B>
@@ -275,8 +476,10 @@ Result<std::shared_ptr<UnboundAggregateImpl<B>>> UnboundAggregateImpl<B>::Make(
275476
}
276477

277478
// Explicit template instantiations of Aggregate and UnboundAggregateImpl for the
// BoundReference, BoundTransform and BoundTerm specializations listed below.
template class Aggregate<UnboundTerm<BoundReference>>;
479+
template class Aggregate<UnboundTerm<BoundTransform>>;
278480
template class Aggregate<BoundTerm>;
279481
template class UnboundAggregateImpl<BoundReference>;
482+
template class UnboundAggregateImpl<BoundTransform>;
280483

281484
// -------------------- AggregateEvaluator --------------------
282485

@@ -296,6 +499,13 @@ class AggregateEvaluatorImpl : public AggregateEvaluator {
296499
return {};
297500
}
298501

502+
// Feeds one data file to every aggregator in order, stopping at and returning the
// first error reported by an aggregator.
Status Update(const DataFile& file) override {
  for (const auto& agg : aggregators_) {
    ICEBERG_RETURN_UNEXPECTED(agg->Update(file));
  }
  return {};
}
508+
299509
Result<std::span<const Literal>> GetResults() const override {
300510
results_.clear();
301511
results_.reserve(aggregates_.size());
@@ -315,6 +525,10 @@ class AggregateEvaluatorImpl : public AggregateEvaluator {
315525
return all.front();
316526
}
317527

528+
bool AllAggregatorsValid() const override {
529+
return std::ranges::all_of(aggregators_, &BoundAggregate::Aggregator::IsValid);
530+
}
531+
318532
private:
319533
std::vector<std::shared_ptr<BoundAggregate>> aggregates_;
320534
std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators_;

0 commit comments

Comments
 (0)