Skip to content

Commit 6fd067f

Browse files
committed
feat: add DataFile aggregate evaluation
1 parent 09f26b6 commit 6fd067f

File tree

10 files changed

+589
-9
lines changed

10 files changed

+589
-9
lines changed

src/iceberg/expression/aggregate.cc

Lines changed: 230 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
#include "iceberg/expression/aggregate.h"
2121

2222
#include <format>
23+
#include <map>
2324
#include <optional>
25+
#include <string_view>
2426
#include <vector>
2527

2628
#include "iceberg/expression/literal.h"
29+
#include "iceberg/manifest/manifest_entry.h"
2730
#include "iceberg/row/struct_like.h"
2831
#include "iceberg/type.h"
2932
#include "iceberg/util/checked_cast.h"
@@ -38,6 +41,19 @@ std::shared_ptr<PrimitiveType> GetPrimitiveType(const BoundTerm& term) {
3841
return internal::checked_pointer_cast<PrimitiveType>(term.type());
3942
}
4043

44+
Result<Literal> EvaluateBoundTerm(const BoundTerm& term,
45+
const std::optional<std::vector<uint8_t>>& bound) {
46+
auto ptype = GetPrimitiveType(term);
47+
if (!bound.has_value()) {
48+
SingleValueStructLike data(Literal::Null(ptype));
49+
return term.Evaluate(data);
50+
}
51+
52+
ICEBERG_ASSIGN_OR_RAISE(auto literal, Literal::Deserialize(*bound, ptype));
53+
SingleValueStructLike data(std::move(literal));
54+
return term.Evaluate(data);
55+
}
56+
4157
class CountAggregator : public BoundAggregate::Aggregator {
4258
public:
4359
explicit CountAggregator(const CountAggregate& aggregate) : aggregate_(aggregate) {}
@@ -48,11 +64,32 @@ class CountAggregator : public BoundAggregate::Aggregator {
4864
return {};
4965
}
5066

51-
Literal GetResult() const override { return Literal::Long(count_); }
67+
Status Update(const DataFile& file) override {
68+
if (!valid_) {
69+
return {};
70+
}
71+
if (!aggregate_.HasValue(file)) {
72+
valid_ = false;
73+
return {};
74+
}
75+
ICEBERG_ASSIGN_OR_RAISE(auto count, aggregate_.CountFor(file));
76+
count_ += count;
77+
return {};
78+
}
79+
80+
Literal GetResult() const override {
81+
if (!valid_) {
82+
return Literal::Null(int64());
83+
}
84+
return Literal::Long(count_);
85+
}
86+
87+
bool IsValid() const override { return valid_; }
5288

5389
private:
5490
const CountAggregate& aggregate_;
5591
int64_t count_ = 0;
92+
bool valid_ = true;
5693
};
5794

5895
class MaxAggregator : public BoundAggregate::Aggregator {
@@ -82,11 +119,47 @@ class MaxAggregator : public BoundAggregate::Aggregator {
82119
return {};
83120
}
84121

85-
Literal GetResult() const override { return current_; }
122+
Status Update(const DataFile& file) override {
123+
if (!valid_) {
124+
return {};
125+
}
126+
if (!aggregate_.HasValue(file)) {
127+
valid_ = false;
128+
return {};
129+
}
130+
131+
ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(file));
132+
if (value.IsNull()) {
133+
return {};
134+
}
135+
if (current_.IsNull()) {
136+
current_ = std::move(value);
137+
return {};
138+
}
139+
140+
if (auto ordering = value <=> current_;
141+
ordering == std::partial_ordering::unordered) {
142+
return InvalidArgument("Cannot compare literal {} with current value {}",
143+
value.ToString(), current_.ToString());
144+
} else if (ordering == std::partial_ordering::greater) {
145+
current_ = std::move(value);
146+
}
147+
return {};
148+
}
149+
150+
Literal GetResult() const override {
151+
if (!valid_) {
152+
return Literal::Null(GetPrimitiveType(*aggregate_.term()));
153+
}
154+
return current_;
155+
}
156+
157+
bool IsValid() const override { return valid_; }
86158

87159
private:
88160
const MaxAggregate& aggregate_;
89161
Literal current_;
162+
bool valid_ = true;
90163
};
91164

92165
class MinAggregator : public BoundAggregate::Aggregator {
@@ -115,13 +188,73 @@ class MinAggregator : public BoundAggregate::Aggregator {
115188
return {};
116189
}
117190

118-
Literal GetResult() const override { return current_; }
191+
Status Update(const DataFile& file) override {
192+
if (!valid_) {
193+
return {};
194+
}
195+
if (!aggregate_.HasValue(file)) {
196+
valid_ = false;
197+
return {};
198+
}
199+
200+
ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(file));
201+
if (value.IsNull()) {
202+
return {};
203+
}
204+
if (current_.IsNull()) {
205+
current_ = std::move(value);
206+
return {};
207+
}
208+
209+
if (auto ordering = value <=> current_;
210+
ordering == std::partial_ordering::unordered) {
211+
return InvalidArgument("Cannot compare literal {} with current value {}",
212+
value.ToString(), current_.ToString());
213+
} else if (ordering == std::partial_ordering::less) {
214+
current_ = std::move(value);
215+
}
216+
return {};
217+
}
218+
219+
Literal GetResult() const override {
220+
if (!valid_) {
221+
return Literal::Null(GetPrimitiveType(*aggregate_.term()));
222+
}
223+
return current_;
224+
}
225+
226+
bool IsValid() const override { return valid_; }
119227

120228
private:
121229
const MinAggregate& aggregate_;
122230
Literal current_;
231+
bool valid_ = true;
123232
};
124233

234+
bool HasMapKey(const std::map<int32_t, int64_t>& map, int32_t key) {
235+
return map.contains(key);
236+
}
237+
238+
bool HasMapKey(const std::map<int32_t, std::vector<uint8_t>>& map, int32_t key) {
239+
return map.contains(key);
240+
}
241+
242+
template <typename T>
243+
std::optional<T> GetMapValue(const std::map<int32_t, T>& map, int32_t key) {
244+
auto iter = map.find(key);
245+
if (iter == map.end()) {
246+
return std::nullopt;
247+
}
248+
return iter->second;
249+
}
250+
251+
int32_t GetFieldId(const std::shared_ptr<BoundTerm>& term) {
252+
ICEBERG_DCHECK(term != nullptr, "Aggregate term should not be null");
253+
auto ref = term->reference();
254+
ICEBERG_DCHECK(ref != nullptr, "Aggregate term reference should not be null");
255+
return ref->field().field_id();
256+
}
257+
125258
} // namespace
126259

127260
template <TermType T>
@@ -152,6 +285,11 @@ Result<Literal> CountAggregate::Evaluate(const StructLike& data) const {
152285
return CountFor(data).transform([](int64_t count) { return Literal::Long(count); });
153286
}
154287

288+
Result<Literal> CountAggregate::Evaluate(const DataFile& file) const {
289+
ICEBERG_ASSIGN_OR_RAISE(auto count, CountFor(file));
290+
return Literal::Long(count);
291+
}
292+
155293
std::unique_ptr<BoundAggregate::Aggregator> CountAggregate::NewAggregator() const {
156294
return std::unique_ptr<BoundAggregate::Aggregator>(new CountAggregator(*this));
157295
}
@@ -173,6 +311,22 @@ Result<int64_t> CountNonNullAggregate::CountFor(const StructLike& data) const {
173311
[](const auto& val) { return val.IsNull() ? 0 : 1; });
174312
}
175313

314+
Result<int64_t> CountNonNullAggregate::CountFor(const DataFile& file) const {
315+
auto field_id = GetFieldId(term());
316+
if (!HasValue(file)) {
317+
return NotFound("Missing metrics for field id {}", field_id);
318+
}
319+
auto value_count = GetMapValue(file.value_counts, field_id).value();
320+
auto null_count = GetMapValue(file.null_value_counts, field_id).value();
321+
return value_count - null_count;
322+
}
323+
324+
bool CountNonNullAggregate::HasValue(const DataFile& file) const {
325+
auto field_id = GetFieldId(term());
326+
return HasMapKey(file.value_counts, field_id) &&
327+
HasMapKey(file.null_value_counts, field_id);
328+
}
329+
176330
CountNullAggregate::CountNullAggregate(std::shared_ptr<BoundTerm> term)
177331
: CountAggregate(Expression::Operation::kCountNull, std::move(term)) {}
178332

@@ -189,6 +343,18 @@ Result<int64_t> CountNullAggregate::CountFor(const StructLike& data) const {
189343
[](const auto& val) { return val.IsNull() ? 1 : 0; });
190344
}
191345

346+
Result<int64_t> CountNullAggregate::CountFor(const DataFile& file) const {
347+
auto field_id = GetFieldId(term());
348+
if (!HasValue(file)) {
349+
return NotFound("Missing metrics for field id {}", field_id);
350+
}
351+
return GetMapValue(file.null_value_counts, field_id).value();
352+
}
353+
354+
bool CountNullAggregate::HasValue(const DataFile& file) const {
355+
return HasMapKey(file.null_value_counts, GetFieldId(term()));
356+
}
357+
192358
CountStarAggregate::CountStarAggregate()
193359
: CountAggregate(Expression::Operation::kCountStar, nullptr) {}
194360

@@ -200,6 +366,17 @@ Result<int64_t> CountStarAggregate::CountFor(const StructLike& /*data*/) const {
200366
return 1;
201367
}
202368

369+
Result<int64_t> CountStarAggregate::CountFor(const DataFile& file) const {
370+
if (!HasValue(file)) {
371+
return NotFound("Record count is missing");
372+
}
373+
return file.record_count;
374+
}
375+
376+
bool CountStarAggregate::HasValue(const DataFile& file) const {
377+
return file.record_count >= 0;
378+
}
379+
203380
MaxAggregate::MaxAggregate(std::shared_ptr<BoundTerm> term)
204381
: BoundAggregate(Expression::Operation::kMax, std::move(term)) {}
205382

@@ -211,10 +388,26 @@ Result<Literal> MaxAggregate::Evaluate(const StructLike& data) const {
211388
return term()->Evaluate(data);
212389
}
213390

391+
Result<Literal> MaxAggregate::Evaluate(const DataFile& file) const {
392+
auto field_id = GetFieldId(term());
393+
auto upper = GetMapValue(file.upper_bounds, field_id);
394+
return EvaluateBoundTerm(*term(), upper);
395+
}
396+
214397
std::unique_ptr<BoundAggregate::Aggregator> MaxAggregate::NewAggregator() const {
215398
return std::unique_ptr<BoundAggregate::Aggregator>(new MaxAggregator(*this));
216399
}
217400

401+
bool MaxAggregate::HasValue(const DataFile& file) const {
402+
auto field_id = GetFieldId(term());
403+
bool has_bound = HasMapKey(file.upper_bounds, field_id);
404+
auto value_count = GetMapValue(file.value_counts, field_id);
405+
auto null_count = GetMapValue(file.null_value_counts, field_id);
406+
bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() &&
407+
null_count.value() == value_count.value();
408+
return has_bound || all_null;
409+
}
410+
218411
MinAggregate::MinAggregate(std::shared_ptr<BoundTerm> term)
219412
: BoundAggregate(Expression::Operation::kMin, std::move(term)) {}
220413

@@ -226,10 +419,26 @@ Result<Literal> MinAggregate::Evaluate(const StructLike& data) const {
226419
return term()->Evaluate(data);
227420
}
228421

422+
Result<Literal> MinAggregate::Evaluate(const DataFile& file) const {
423+
auto field_id = GetFieldId(term());
424+
auto lower = GetMapValue(file.lower_bounds, field_id);
425+
return EvaluateBoundTerm(*term(), lower);
426+
}
427+
229428
std::unique_ptr<BoundAggregate::Aggregator> MinAggregate::NewAggregator() const {
230429
return std::unique_ptr<BoundAggregate::Aggregator>(new MinAggregator(*this));
231430
}
232431

432+
bool MinAggregate::HasValue(const DataFile& file) const {
433+
auto field_id = GetFieldId(term());
434+
bool has_bound = HasMapKey(file.lower_bounds, field_id);
435+
auto value_count = GetMapValue(file.value_counts, field_id);
436+
auto null_count = GetMapValue(file.null_value_counts, field_id);
437+
bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() &&
438+
null_count.value() == value_count.value();
439+
return has_bound || all_null;
440+
}
441+
233442
// -------------------- Unbound binding --------------------
234443

235444
template <typename B>
@@ -275,8 +484,10 @@ Result<std::shared_ptr<UnboundAggregateImpl<B>>> UnboundAggregateImpl<B>::Make(
275484
}
276485

277486
template class Aggregate<UnboundTerm<BoundReference>>;
487+
template class Aggregate<UnboundTerm<BoundTransform>>;
278488
template class Aggregate<BoundTerm>;
279489
template class UnboundAggregateImpl<BoundReference>;
490+
template class UnboundAggregateImpl<BoundTransform>;
280491

281492
// -------------------- AggregateEvaluator --------------------
282493

@@ -296,6 +507,13 @@ class AggregateEvaluatorImpl : public AggregateEvaluator {
296507
return {};
297508
}
298509

510+
Status Update(const DataFile& file) override {
511+
for (auto& aggregator : aggregators_) {
512+
ICEBERG_RETURN_UNEXPECTED(aggregator->Update(file));
513+
}
514+
return {};
515+
}
516+
299517
Result<std::span<const Literal>> GetResults() const override {
300518
results_.clear();
301519
results_.reserve(aggregates_.size());
@@ -315,6 +533,15 @@ class AggregateEvaluatorImpl : public AggregateEvaluator {
315533
return all.front();
316534
}
317535

536+
bool AllAggregatorsValid() const override {
537+
for (const auto& aggregator : aggregators_) {
538+
if (!aggregator->IsValid()) {
539+
return false;
540+
}
541+
}
542+
return true;
543+
}
544+
318545
private:
319546
std::vector<std::shared_ptr<BoundAggregate>> aggregates_;
320547
std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators_;

0 commit comments

Comments
 (0)