|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | + * or more contributor license agreements. See the NOTICE file |
| 4 | + * distributed with this work for additional information |
| 5 | + * regarding copyright ownership. The ASF licenses this file |
| 6 | + * to you under the Apache License, Version 2.0 (the |
| 7 | + * "License"); you may not use this file except in compliance |
| 8 | + * with the License. You may obtain a copy of the License at |
| 9 | + * |
| 10 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | + * |
| 12 | + * Unless required by applicable law or agreed to in writing, |
| 13 | + * software distributed under the License is distributed on an |
| 14 | + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | + * KIND, either express or implied. See the License for the |
| 16 | + * specific language governing permissions and limitations |
| 17 | + * under the License. |
| 18 | + */ |
| 19 | + |
| 20 | +#include "iceberg/expression/aggregate.h" |
| 21 | + |
| 22 | +#include <format> |
| 23 | +#include <optional> |
| 24 | +#include <vector> |
| 25 | + |
| 26 | +#include "iceberg/expression/literal.h" |
| 27 | +#include "iceberg/row/struct_like.h" |
| 28 | +#include "iceberg/type.h" |
| 29 | +#include "iceberg/util/checked_cast.h" |
| 30 | +#include "iceberg/util/macros.h" |
| 31 | + |
| 32 | +namespace iceberg { |
| 33 | + |
| 34 | +namespace { |
| 35 | + |
| 36 | +std::shared_ptr<PrimitiveType> GetPrimitiveType(const BoundTerm& term) { |
| 37 | + ICEBERG_DCHECK(term.type()->is_primitive(), "Value aggregate term should be primitive"); |
| 38 | + return internal::checked_pointer_cast<PrimitiveType>(term.type()); |
| 39 | +} |
| 40 | + |
| 41 | +class CountAggregator : public BoundAggregate::Aggregator { |
| 42 | + public: |
| 43 | + explicit CountAggregator(const CountAggregate& aggregate) : aggregate_(aggregate) {} |
| 44 | + |
| 45 | + Status Update(const StructLike& row) override { |
| 46 | + ICEBERG_ASSIGN_OR_RAISE(auto count, aggregate_.CountFor(row)); |
| 47 | + count_ += count; |
| 48 | + return {}; |
| 49 | + } |
| 50 | + |
| 51 | + Literal GetResult() const override { return Literal::Long(count_); } |
| 52 | + |
| 53 | + private: |
| 54 | + const CountAggregate& aggregate_; |
| 55 | + int64_t count_ = 0; |
| 56 | +}; |
| 57 | + |
| 58 | +class MaxAggregator : public BoundAggregate::Aggregator { |
| 59 | + public: |
| 60 | + explicit MaxAggregator(const MaxAggregate& aggregate) |
| 61 | + : aggregate_(aggregate), |
| 62 | + current_(Literal::Null(GetPrimitiveType(*aggregate_.term()))) {} |
| 63 | + |
| 64 | + Status Update(const StructLike& data) override { |
| 65 | + ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(data)); |
| 66 | + if (value.IsNull()) { |
| 67 | + return {}; |
| 68 | + } |
| 69 | + if (current_.IsNull()) { |
| 70 | + current_ = std::move(value); |
| 71 | + return {}; |
| 72 | + } |
| 73 | + |
| 74 | + if (auto ordering = value <=> current_; |
| 75 | + ordering == std::partial_ordering::unordered) { |
| 76 | + return InvalidArgument("Cannot compare literal {} with current value {}", |
| 77 | + value.ToString(), current_.ToString()); |
| 78 | + } else if (ordering == std::partial_ordering::greater) { |
| 79 | + current_ = std::move(value); |
| 80 | + } |
| 81 | + |
| 82 | + return {}; |
| 83 | + } |
| 84 | + |
| 85 | + Literal GetResult() const override { return current_; } |
| 86 | + |
| 87 | + private: |
| 88 | + const MaxAggregate& aggregate_; |
| 89 | + Literal current_; |
| 90 | +}; |
| 91 | + |
| 92 | +class MinAggregator : public BoundAggregate::Aggregator { |
| 93 | + public: |
| 94 | + explicit MinAggregator(const MinAggregate& aggregate) |
| 95 | + : aggregate_(aggregate), |
| 96 | + current_(Literal::Null(GetPrimitiveType(*aggregate_.term()))) {} |
| 97 | + |
| 98 | + Status Update(const StructLike& data) override { |
| 99 | + ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(data)); |
| 100 | + if (value.IsNull()) { |
| 101 | + return {}; |
| 102 | + } |
| 103 | + if (current_.IsNull()) { |
| 104 | + current_ = std::move(value); |
| 105 | + return {}; |
| 106 | + } |
| 107 | + |
| 108 | + if (auto ordering = value <=> current_; |
| 109 | + ordering == std::partial_ordering::unordered) { |
| 110 | + return InvalidArgument("Cannot compare literal {} with current value {}", |
| 111 | + value.ToString(), current_.ToString()); |
| 112 | + } else if (ordering == std::partial_ordering::less) { |
| 113 | + current_ = std::move(value); |
| 114 | + } |
| 115 | + return {}; |
| 116 | + } |
| 117 | + |
| 118 | + Literal GetResult() const override { return current_; } |
| 119 | + |
| 120 | + private: |
| 121 | + const MinAggregate& aggregate_; |
| 122 | + Literal current_; |
| 123 | +}; |
| 124 | + |
| 125 | +} // namespace |
| 126 | + |
// Renders the aggregate as text, e.g. "count(<term>)" or "count(*)".
// Operations outside the handled set are rendered as "Invalid aggregate: <op>".
template <TermType T>
std::string Aggregate<T>::ToString() const {
  ICEBERG_DCHECK(IsSupportedOp(op()), "Unexpected aggregate operation");
  ICEBERG_DCHECK(op() == Expression::Operation::kCountStar || term() != nullptr,
                 "Aggregate term should not be null except for COUNT(*)");

  const auto operation = op();
  // Evaluated lazily so the term is dereferenced only by operations that
  // actually print it.
  const auto operand = [this] { return this->term()->ToString(); };

  if (operation == Expression::Operation::kCountStar) {
    return "count(*)";
  }
  if (operation == Expression::Operation::kCount) {
    return std::format("count({})", operand());
  }
  if (operation == Expression::Operation::kCountNull) {
    return std::format("count_if({} is null)", operand());
  }
  if (operation == Expression::Operation::kMax) {
    return std::format("max({})", operand());
  }
  if (operation == Expression::Operation::kMin) {
    return std::format("min({})", operand());
  }
  return std::format("Invalid aggregate: {}", ::iceberg::ToString(operation));
}
| 148 | + |
| 149 | +// -------------------- CountAggregate -------------------- |
| 150 | + |
| 151 | +Result<Literal> CountAggregate::Evaluate(const StructLike& data) const { |
| 152 | + return CountFor(data).transform([](int64_t count) { return Literal::Long(count); }); |
| 153 | +} |
| 154 | + |
| 155 | +std::unique_ptr<BoundAggregate::Aggregator> CountAggregate::NewAggregator() const { |
| 156 | + return std::unique_ptr<BoundAggregate::Aggregator>(new CountAggregator(*this)); |
| 157 | +} |
| 158 | + |
| 159 | +CountNonNullAggregate::CountNonNullAggregate(std::shared_ptr<BoundTerm> term) |
| 160 | + : CountAggregate(Expression::Operation::kCount, std::move(term)) {} |
| 161 | + |
| 162 | +Result<std::unique_ptr<CountNonNullAggregate>> CountNonNullAggregate::Make( |
| 163 | + std::shared_ptr<BoundTerm> term) { |
| 164 | + if (!term) { |
| 165 | + return InvalidExpression("Bound count aggregate requires non-null term"); |
| 166 | + } |
| 167 | + return std::unique_ptr<CountNonNullAggregate>( |
| 168 | + new CountNonNullAggregate(std::move(term))); |
| 169 | +} |
| 170 | + |
| 171 | +Result<int64_t> CountNonNullAggregate::CountFor(const StructLike& data) const { |
| 172 | + return term()->Evaluate(data).transform( |
| 173 | + [](const auto& val) { return val.IsNull() ? 0 : 1; }); |
| 174 | +} |
| 175 | + |
| 176 | +CountNullAggregate::CountNullAggregate(std::shared_ptr<BoundTerm> term) |
| 177 | + : CountAggregate(Expression::Operation::kCountNull, std::move(term)) {} |
| 178 | + |
| 179 | +Result<std::unique_ptr<CountNullAggregate>> CountNullAggregate::Make( |
| 180 | + std::shared_ptr<BoundTerm> term) { |
| 181 | + if (!term) { |
| 182 | + return InvalidExpression("Bound count aggregate requires non-null term"); |
| 183 | + } |
| 184 | + return std::unique_ptr<CountNullAggregate>(new CountNullAggregate(std::move(term))); |
| 185 | +} |
| 186 | + |
| 187 | +Result<int64_t> CountNullAggregate::CountFor(const StructLike& data) const { |
| 188 | + return term()->Evaluate(data).transform( |
| 189 | + [](const auto& val) { return val.IsNull() ? 1 : 0; }); |
| 190 | +} |
| 191 | + |
| 192 | +CountStarAggregate::CountStarAggregate() |
| 193 | + : CountAggregate(Expression::Operation::kCountStar, nullptr) {} |
| 194 | + |
| 195 | +Result<std::unique_ptr<CountStarAggregate>> CountStarAggregate::Make() { |
| 196 | + return std::unique_ptr<CountStarAggregate>(new CountStarAggregate()); |
| 197 | +} |
| 198 | + |
| 199 | +Result<int64_t> CountStarAggregate::CountFor(const StructLike& /*data*/) const { |
| 200 | + return 1; |
| 201 | +} |
| 202 | + |
| 203 | +MaxAggregate::MaxAggregate(std::shared_ptr<BoundTerm> term) |
| 204 | + : BoundAggregate(Expression::Operation::kMax, std::move(term)) {} |
| 205 | + |
| 206 | +std::shared_ptr<MaxAggregate> MaxAggregate::Make(std::shared_ptr<BoundTerm> term) { |
| 207 | + return std::shared_ptr<MaxAggregate>(new MaxAggregate(std::move(term))); |
| 208 | +} |
| 209 | + |
| 210 | +Result<Literal> MaxAggregate::Evaluate(const StructLike& data) const { |
| 211 | + return term()->Evaluate(data); |
| 212 | +} |
| 213 | + |
| 214 | +std::unique_ptr<BoundAggregate::Aggregator> MaxAggregate::NewAggregator() const { |
| 215 | + return std::unique_ptr<BoundAggregate::Aggregator>(new MaxAggregator(*this)); |
| 216 | +} |
| 217 | + |
| 218 | +MinAggregate::MinAggregate(std::shared_ptr<BoundTerm> term) |
| 219 | + : BoundAggregate(Expression::Operation::kMin, std::move(term)) {} |
| 220 | + |
| 221 | +std::shared_ptr<MinAggregate> MinAggregate::Make(std::shared_ptr<BoundTerm> term) { |
| 222 | + return std::shared_ptr<MinAggregate>(new MinAggregate(std::move(term))); |
| 223 | +} |
| 224 | + |
| 225 | +Result<Literal> MinAggregate::Evaluate(const StructLike& data) const { |
| 226 | + return term()->Evaluate(data); |
| 227 | +} |
| 228 | + |
| 229 | +std::unique_ptr<BoundAggregate::Aggregator> MinAggregate::NewAggregator() const { |
| 230 | + return std::unique_ptr<BoundAggregate::Aggregator>(new MinAggregator(*this)); |
| 231 | +} |
| 232 | + |
| 233 | +// -------------------- Unbound binding -------------------- |
| 234 | + |
| 235 | +template <typename B> |
| 236 | +Result<std::shared_ptr<Expression>> UnboundAggregateImpl<B>::Bind( |
| 237 | + const Schema& schema, bool case_sensitive) const { |
| 238 | + ICEBERG_DCHECK(UnboundAggregateImpl<B>::IsSupportedOp(this->op()), |
| 239 | + "Unexpected aggregate operation"); |
| 240 | + |
| 241 | + std::shared_ptr<B> bound_term; |
| 242 | + if (this->term()) { |
| 243 | + ICEBERG_ASSIGN_OR_RAISE(bound_term, this->term()->Bind(schema, case_sensitive)); |
| 244 | + } |
| 245 | + |
| 246 | + switch (this->op()) { |
| 247 | + case Expression::Operation::kCountStar: |
| 248 | + return CountStarAggregate::Make(); |
| 249 | + case Expression::Operation::kCount: |
| 250 | + return CountNonNullAggregate::Make(std::move(bound_term)); |
| 251 | + case Expression::Operation::kCountNull: |
| 252 | + return CountNullAggregate::Make(std::move(bound_term)); |
| 253 | + case Expression::Operation::kMax: |
| 254 | + return MaxAggregate::Make(std::move(bound_term)); |
| 255 | + case Expression::Operation::kMin: |
| 256 | + return MinAggregate::Make(std::move(bound_term)); |
| 257 | + default: |
| 258 | + return NotSupported("Unsupported aggregate operation: {}", |
| 259 | + ::iceberg::ToString(this->op())); |
| 260 | + } |
| 261 | +} |
| 262 | + |
| 263 | +template <typename B> |
| 264 | +Result<std::shared_ptr<UnboundAggregateImpl<B>>> UnboundAggregateImpl<B>::Make( |
| 265 | + Expression::Operation op, std::shared_ptr<UnboundTerm<B>> term) { |
| 266 | + if (!Aggregate<UnboundTerm<B>>::IsSupportedOp(op)) { |
| 267 | + return NotSupported("Unsupported aggregate operation: {}", ::iceberg::ToString(op)); |
| 268 | + } |
| 269 | + if (op != Expression::Operation::kCountStar && !term) { |
| 270 | + return InvalidExpression("Aggregate term cannot be null unless COUNT(*)"); |
| 271 | + } |
| 272 | + |
| 273 | + return std::shared_ptr<UnboundAggregateImpl<B>>( |
| 274 | + new UnboundAggregateImpl<B>(op, std::move(term))); |
| 275 | +} |
| 276 | + |
// Explicit instantiations: the member templates above are defined in this
// source file, so the specializations used by the library are instantiated
// here.
template class Aggregate<UnboundTerm<BoundReference>>;
template class Aggregate<BoundTerm>;
template class UnboundAggregateImpl<BoundReference>;
| 280 | + |
| 281 | +// -------------------- AggregateEvaluator -------------------- |
| 282 | + |
| 283 | +namespace { |
| 284 | + |
| 285 | +class AggregateEvaluatorImpl : public AggregateEvaluator { |
| 286 | + public: |
| 287 | + AggregateEvaluatorImpl( |
| 288 | + std::vector<std::shared_ptr<BoundAggregate>> aggregates, |
| 289 | + std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators) |
| 290 | + : aggregates_(std::move(aggregates)), aggregators_(std::move(aggregators)) {} |
| 291 | + |
| 292 | + Status Update(const StructLike& data) override { |
| 293 | + for (auto& aggregator : aggregators_) { |
| 294 | + ICEBERG_RETURN_UNEXPECTED(aggregator->Update(data)); |
| 295 | + } |
| 296 | + return {}; |
| 297 | + } |
| 298 | + |
| 299 | + Result<std::span<const Literal>> GetResults() const override { |
| 300 | + results_.clear(); |
| 301 | + results_.reserve(aggregates_.size()); |
| 302 | + for (const auto& aggregator : aggregators_) { |
| 303 | + results_.emplace_back(aggregator->GetResult()); |
| 304 | + } |
| 305 | + return std::span<const Literal>(results_); |
| 306 | + } |
| 307 | + |
| 308 | + Result<Literal> GetResult() const override { |
| 309 | + if (aggregates_.size() != 1) { |
| 310 | + return InvalidArgument( |
| 311 | + "GetResult() is only valid when evaluating a single aggregate"); |
| 312 | + } |
| 313 | + |
| 314 | + ICEBERG_ASSIGN_OR_RAISE(auto all, GetResults()); |
| 315 | + return all.front(); |
| 316 | + } |
| 317 | + |
| 318 | + private: |
| 319 | + std::vector<std::shared_ptr<BoundAggregate>> aggregates_; |
| 320 | + std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators_; |
| 321 | + mutable std::vector<Literal> results_; |
| 322 | +}; |
| 323 | + |
| 324 | +} // namespace |
| 325 | + |
| 326 | +Result<std::unique_ptr<AggregateEvaluator>> AggregateEvaluator::Make( |
| 327 | + std::shared_ptr<BoundAggregate> aggregate) { |
| 328 | + std::vector<std::shared_ptr<BoundAggregate>> aggs; |
| 329 | + aggs.push_back(std::move(aggregate)); |
| 330 | + return Make(std::move(aggs)); |
| 331 | +} |
| 332 | + |
| 333 | +Result<std::unique_ptr<AggregateEvaluator>> AggregateEvaluator::Make( |
| 334 | + std::vector<std::shared_ptr<BoundAggregate>> aggregates) { |
| 335 | + if (aggregates.empty()) { |
| 336 | + return InvalidArgument("AggregateEvaluator requires at least one aggregate"); |
| 337 | + } |
| 338 | + std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators; |
| 339 | + aggregators.reserve(aggregates.size()); |
| 340 | + for (const auto& agg : aggregates) { |
| 341 | + aggregators.push_back(agg->NewAggregator()); |
| 342 | + } |
| 343 | + |
| 344 | + return std::unique_ptr<AggregateEvaluator>( |
| 345 | + new AggregateEvaluatorImpl(std::move(aggregates), std::move(aggregators))); |
| 346 | +} |
| 347 | + |
| 348 | +} // namespace iceberg |
0 commit comments