Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 235 additions & 4 deletions src/iceberg/expression/aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@

#include "iceberg/expression/aggregate.h"

#include <algorithm>
#include <format>
#include <map>
#include <optional>
#include <string_view>
#include <vector>

#include "iceberg/expression/literal.h"
#include "iceberg/manifest/manifest_entry.h"
#include "iceberg/row/struct_like.h"
#include "iceberg/type.h"
#include "iceberg/util/checked_cast.h"
Expand All @@ -38,6 +42,32 @@ std::shared_ptr<PrimitiveType> GetPrimitiveType(const BoundTerm& term) {
return internal::checked_pointer_cast<PrimitiveType>(term.type());
}

/// \brief A single-field StructLike that wraps a Literal
class SingleValueStructLike : public StructLike {
public:
explicit SingleValueStructLike(Literal literal) : literal_(std::move(literal)) {}

Result<Scalar> GetField(size_t) const override { return LiteralToScalar(literal_); }

size_t num_fields() const override { return 1; }

private:
Literal literal_;
};

Result<Literal> EvaluateBoundTerm(const BoundTerm& term,
const std::optional<std::vector<uint8_t>>& bound) {
auto ptype = GetPrimitiveType(term);
if (!bound.has_value()) {
SingleValueStructLike data(Literal::Null(ptype));
return term.Evaluate(data);
}

ICEBERG_ASSIGN_OR_RAISE(auto literal, Literal::Deserialize(*bound, ptype));
SingleValueStructLike data(std::move(literal));
return term.Evaluate(data);
}

class CountAggregator : public BoundAggregate::Aggregator {
public:
explicit CountAggregator(const CountAggregate& aggregate) : aggregate_(aggregate) {}
Expand All @@ -48,11 +78,32 @@ class CountAggregator : public BoundAggregate::Aggregator {
return {};
}

Literal GetResult() const override { return Literal::Long(count_); }
Status Update(const DataFile& file) override {
if (!valid_) {
return {};
}
if (!aggregate_.HasValue(file)) {
valid_ = false;
return {};
}
ICEBERG_ASSIGN_OR_RAISE(auto count, aggregate_.CountFor(file));
count_ += count;
return {};
}

Literal GetResult() const override {
if (!valid_) {
return Literal::Null(int64());
}
return Literal::Long(count_);
}

bool IsValid() const override { return valid_; }

private:
const CountAggregate& aggregate_;
int64_t count_ = 0;
bool valid_ = true;
};

class MaxAggregator : public BoundAggregate::Aggregator {
Expand All @@ -73,6 +124,7 @@ class MaxAggregator : public BoundAggregate::Aggregator {

if (auto ordering = value <=> current_;
ordering == std::partial_ordering::unordered) {
valid_ = false;
return InvalidArgument("Cannot compare literal {} with current value {}",
value.ToString(), current_.ToString());
} else if (ordering == std::partial_ordering::greater) {
Expand All @@ -82,11 +134,48 @@ class MaxAggregator : public BoundAggregate::Aggregator {
return {};
}

Literal GetResult() const override { return current_; }
Status Update(const DataFile& file) override {
if (!valid_) {
return {};
}
if (!aggregate_.HasValue(file)) {
valid_ = false;
return {};
}

ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(file));
if (value.IsNull()) {
return {};
}
if (current_.IsNull()) {
current_ = std::move(value);
return {};
}

if (auto ordering = value <=> current_;
ordering == std::partial_ordering::unordered) {
valid_ = false;
return InvalidArgument("Cannot compare literal {} with current value {}",
value.ToString(), current_.ToString());
} else if (ordering == std::partial_ordering::greater) {
current_ = std::move(value);
}
return {};
}

Literal GetResult() const override {
if (!valid_) {
return Literal::Null(GetPrimitiveType(*aggregate_.term()));
}
return current_;
}

bool IsValid() const override { return valid_; }

private:
const MaxAggregate& aggregate_;
Literal current_;
bool valid_ = true;
};

class MinAggregator : public BoundAggregate::Aggregator {
Expand All @@ -107,6 +196,7 @@ class MinAggregator : public BoundAggregate::Aggregator {

if (auto ordering = value <=> current_;
ordering == std::partial_ordering::unordered) {
valid_ = false;
return InvalidArgument("Cannot compare literal {} with current value {}",
value.ToString(), current_.ToString());
} else if (ordering == std::partial_ordering::less) {
Expand All @@ -115,13 +205,66 @@ class MinAggregator : public BoundAggregate::Aggregator {
return {};
}

Literal GetResult() const override { return current_; }
Status Update(const DataFile& file) override {
if (!valid_) {
return {};
}
if (!aggregate_.HasValue(file)) {
valid_ = false;
return {};
}

ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(file));
if (value.IsNull()) {
return {};
}
if (current_.IsNull()) {
current_ = std::move(value);
return {};
}

if (auto ordering = value <=> current_;
ordering == std::partial_ordering::unordered) {
valid_ = false;
return InvalidArgument("Cannot compare literal {} with current value {}",
value.ToString(), current_.ToString());
} else if (ordering == std::partial_ordering::less) {
current_ = std::move(value);
}
return {};
}

Literal GetResult() const override {
if (!valid_) {
return Literal::Null(GetPrimitiveType(*aggregate_.term()));
}
return current_;
}

bool IsValid() const override { return valid_; }

private:
const MinAggregate& aggregate_;
Literal current_;
bool valid_ = true;
};

template <typename T>
std::optional<T> GetMapValue(const std::map<int32_t, T>& map, int32_t key) {
auto iter = map.find(key);
if (iter == map.end()) {
return std::nullopt;
}
return iter->second;
}

int32_t GetFieldId(const std::shared_ptr<BoundTerm>& term) {
ICEBERG_DCHECK(term != nullptr, "Aggregate term should not be null");
auto ref = term->reference();
ICEBERG_DCHECK(ref != nullptr, "Aggregate term reference should not be null");
return ref->field().field_id();
}

} // namespace

template <TermType T>
Expand Down Expand Up @@ -149,7 +292,11 @@ std::string Aggregate<T>::ToString() const {
// -------------------- CountAggregate --------------------

Result<Literal> CountAggregate::Evaluate(const StructLike& data) const {
return CountFor(data).transform([](int64_t count) { return Literal::Long(count); });
return CountFor(data).transform(Literal::Long);
}

Result<Literal> CountAggregate::Evaluate(const DataFile& file) const {
return CountFor(file).transform(Literal::Long);
}

std::unique_ptr<BoundAggregate::Aggregator> CountAggregate::NewAggregator() const {
Expand All @@ -173,6 +320,22 @@ Result<int64_t> CountNonNullAggregate::CountFor(const StructLike& data) const {
[](const auto& val) { return val.IsNull() ? 0 : 1; });
}

Result<int64_t> CountNonNullAggregate::CountFor(const DataFile& file) const {
auto field_id = GetFieldId(term());
if (!HasValue(file)) {
return NotFound("Missing metrics for field id {}", field_id);
}
auto value_count = GetMapValue(file.value_counts, field_id).value();
auto null_count = GetMapValue(file.null_value_counts, field_id).value();
return value_count - null_count;
}

bool CountNonNullAggregate::HasValue(const DataFile& file) const {
auto field_id = GetFieldId(term());
return file.value_counts.contains(field_id) &&
file.null_value_counts.contains(field_id);
}

CountNullAggregate::CountNullAggregate(std::shared_ptr<BoundTerm> term)
: CountAggregate(Expression::Operation::kCountNull, std::move(term)) {}

Expand All @@ -189,6 +352,18 @@ Result<int64_t> CountNullAggregate::CountFor(const StructLike& data) const {
[](const auto& val) { return val.IsNull() ? 1 : 0; });
}

Result<int64_t> CountNullAggregate::CountFor(const DataFile& file) const {
auto field_id = GetFieldId(term());
if (!HasValue(file)) {
return NotFound("Missing metrics for field id {}", field_id);
}
return GetMapValue(file.null_value_counts, field_id).value();
}

bool CountNullAggregate::HasValue(const DataFile& file) const {
return file.null_value_counts.contains(GetFieldId(term()));
}

CountStarAggregate::CountStarAggregate()
: CountAggregate(Expression::Operation::kCountStar, nullptr) {}

Expand All @@ -200,6 +375,17 @@ Result<int64_t> CountStarAggregate::CountFor(const StructLike& /*data*/) const {
return 1;
}

Result<int64_t> CountStarAggregate::CountFor(const DataFile& file) const {
if (!HasValue(file)) {
return NotFound("Record count is missing");
}
return file.record_count;
}

bool CountStarAggregate::HasValue(const DataFile& file) const {
return file.record_count >= 0;
}

MaxAggregate::MaxAggregate(std::shared_ptr<BoundTerm> term)
: BoundAggregate(Expression::Operation::kMax, std::move(term)) {}

Expand All @@ -211,10 +397,26 @@ Result<Literal> MaxAggregate::Evaluate(const StructLike& data) const {
return term()->Evaluate(data);
}

Result<Literal> MaxAggregate::Evaluate(const DataFile& file) const {
auto field_id = GetFieldId(term());
auto upper = GetMapValue(file.upper_bounds, field_id);
return EvaluateBoundTerm(*term(), upper);
}

std::unique_ptr<BoundAggregate::Aggregator> MaxAggregate::NewAggregator() const {
return std::unique_ptr<BoundAggregate::Aggregator>(new MaxAggregator(*this));
}

bool MaxAggregate::HasValue(const DataFile& file) const {
auto field_id = GetFieldId(term());
bool has_bound = file.upper_bounds.contains(field_id);
auto value_count = GetMapValue(file.value_counts, field_id);
auto null_count = GetMapValue(file.null_value_counts, field_id);
bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() &&
null_count.value() == value_count.value();
return has_bound || all_null;
}

MinAggregate::MinAggregate(std::shared_ptr<BoundTerm> term)
: BoundAggregate(Expression::Operation::kMin, std::move(term)) {}

Expand All @@ -226,10 +428,26 @@ Result<Literal> MinAggregate::Evaluate(const StructLike& data) const {
return term()->Evaluate(data);
}

Result<Literal> MinAggregate::Evaluate(const DataFile& file) const {
auto field_id = GetFieldId(term());
auto lower = GetMapValue(file.lower_bounds, field_id);
return EvaluateBoundTerm(*term(), lower);
}

std::unique_ptr<BoundAggregate::Aggregator> MinAggregate::NewAggregator() const {
return std::unique_ptr<BoundAggregate::Aggregator>(new MinAggregator(*this));
}

bool MinAggregate::HasValue(const DataFile& file) const {
auto field_id = GetFieldId(term());
bool has_bound = file.lower_bounds.contains(field_id);
auto value_count = GetMapValue(file.value_counts, field_id);
auto null_count = GetMapValue(file.null_value_counts, field_id);
bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() &&
null_count.value() == value_count.value();
return has_bound || all_null;
}

// -------------------- Unbound binding --------------------

template <typename B>
Expand Down Expand Up @@ -275,8 +493,10 @@ Result<std::shared_ptr<UnboundAggregateImpl<B>>> UnboundAggregateImpl<B>::Make(
}

template class Aggregate<UnboundTerm<BoundReference>>;
template class Aggregate<UnboundTerm<BoundTransform>>;
template class Aggregate<BoundTerm>;
template class UnboundAggregateImpl<BoundReference>;
template class UnboundAggregateImpl<BoundTransform>;

// -------------------- AggregateEvaluator --------------------

Expand All @@ -296,6 +516,13 @@ class AggregateEvaluatorImpl : public AggregateEvaluator {
return {};
}

Status Update(const DataFile& file) override {
for (auto& aggregator : aggregators_) {
ICEBERG_RETURN_UNEXPECTED(aggregator->Update(file));
}
return {};
}

Result<std::span<const Literal>> GetResults() const override {
results_.clear();
results_.reserve(aggregates_.size());
Expand All @@ -315,6 +542,10 @@ class AggregateEvaluatorImpl : public AggregateEvaluator {
return all.front();
}

bool AllAggregatorsValid() const override {
return std::ranges::all_of(aggregators_, &BoundAggregate::Aggregator::IsValid);
}

private:
std::vector<std::shared_ptr<BoundAggregate>> aggregates_;
std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators_;
Expand Down
Loading
Loading