|
19 | 19 |
|
20 | 20 | #include "iceberg/expression/aggregate.h" |
21 | 21 |
|
| 22 | +#include <algorithm> |
22 | 23 | #include <format> |
23 | 24 | #include <map> |
24 | 25 | #include <optional> |
@@ -231,14 +232,6 @@ class MinAggregator : public BoundAggregate::Aggregator { |
231 | 232 | bool valid_ = true; |
232 | 233 | }; |
233 | 234 |
|
234 | | -bool HasMapKey(const std::map<int32_t, int64_t>& map, int32_t key) { |
235 | | - return map.contains(key); |
236 | | -} |
237 | | - |
238 | | -bool HasMapKey(const std::map<int32_t, std::vector<uint8_t>>& map, int32_t key) { |
239 | | - return map.contains(key); |
240 | | -} |
241 | | - |
242 | 235 | template <typename T> |
243 | 236 | std::optional<T> GetMapValue(const std::map<int32_t, T>& map, int32_t key) { |
244 | 237 | auto iter = map.find(key); |
@@ -282,12 +275,11 @@ std::string Aggregate<T>::ToString() const { |
282 | 275 | // -------------------- CountAggregate -------------------- |
283 | 276 |
|
284 | 277 | Result<Literal> CountAggregate::Evaluate(const StructLike& data) const { |
285 | | - return CountFor(data).transform([](int64_t count) { return Literal::Long(count); }); |
| 278 | + return CountFor(data).transform(Literal::Long); |
286 | 279 | } |
287 | 280 |
|
288 | 281 | Result<Literal> CountAggregate::Evaluate(const DataFile& file) const { |
289 | | - ICEBERG_ASSIGN_OR_RAISE(auto count, CountFor(file)); |
290 | | - return Literal::Long(count); |
| 282 | + return CountFor(file).transform(Literal::Long); |
291 | 283 | } |
292 | 284 |
|
293 | 285 | std::unique_ptr<BoundAggregate::Aggregator> CountAggregate::NewAggregator() const { |
@@ -323,8 +315,8 @@ Result<int64_t> CountNonNullAggregate::CountFor(const DataFile& file) const { |
323 | 315 |
|
324 | 316 | bool CountNonNullAggregate::HasValue(const DataFile& file) const { |
325 | 317 | auto field_id = GetFieldId(term()); |
326 | | - return HasMapKey(file.value_counts, field_id) && |
327 | | - HasMapKey(file.null_value_counts, field_id); |
| 318 | + return file.value_counts.contains(field_id) && |
| 319 | + file.null_value_counts.contains(field_id); |
328 | 320 | } |
329 | 321 |
|
330 | 322 | CountNullAggregate::CountNullAggregate(std::shared_ptr<BoundTerm> term) |
@@ -352,7 +344,7 @@ Result<int64_t> CountNullAggregate::CountFor(const DataFile& file) const { |
352 | 344 | } |
353 | 345 |
|
354 | 346 | bool CountNullAggregate::HasValue(const DataFile& file) const { |
355 | | - return HasMapKey(file.null_value_counts, GetFieldId(term())); |
| 347 | + return file.null_value_counts.contains(GetFieldId(term())); |
356 | 348 | } |
357 | 349 |
|
358 | 350 | CountStarAggregate::CountStarAggregate() |
@@ -400,7 +392,7 @@ std::unique_ptr<BoundAggregate::Aggregator> MaxAggregate::NewAggregator() const |
400 | 392 |
|
401 | 393 | bool MaxAggregate::HasValue(const DataFile& file) const { |
402 | 394 | auto field_id = GetFieldId(term()); |
403 | | - bool has_bound = HasMapKey(file.upper_bounds, field_id); |
| 395 | + bool has_bound = file.upper_bounds.contains(field_id); |
404 | 396 | auto value_count = GetMapValue(file.value_counts, field_id); |
405 | 397 | auto null_count = GetMapValue(file.null_value_counts, field_id); |
406 | 398 | bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() && |
@@ -431,7 +423,7 @@ std::unique_ptr<BoundAggregate::Aggregator> MinAggregate::NewAggregator() const |
431 | 423 |
|
432 | 424 | bool MinAggregate::HasValue(const DataFile& file) const { |
433 | 425 | auto field_id = GetFieldId(term()); |
434 | | - bool has_bound = HasMapKey(file.lower_bounds, field_id); |
| 426 | + bool has_bound = file.lower_bounds.contains(field_id); |
435 | 427 | auto value_count = GetMapValue(file.value_counts, field_id); |
436 | 428 | auto null_count = GetMapValue(file.null_value_counts, field_id); |
437 | 429 | bool all_null = value_count.has_value() && *value_count > 0 && null_count.has_value() && |
@@ -534,12 +526,7 @@ class AggregateEvaluatorImpl : public AggregateEvaluator { |
534 | 526 | } |
535 | 527 |
|
536 | 528 | bool AllAggregatorsValid() const override { |
537 | | - for (const auto& aggregator : aggregators_) { |
538 | | - if (!aggregator->IsValid()) { |
539 | | - return false; |
540 | | - } |
541 | | - } |
542 | | - return true; |
| 529 | + return std::ranges::all_of(aggregators_, &BoundAggregate::Aggregator::IsValid); |
543 | 530 | } |
544 | 531 |
|
545 | 532 | private: |
|
0 commit comments