Skip to content

Commit dbcbdf2

Browse files
authored
feat: add aggregate expressions and evaluator (#335)
This PR addresses issue #330 by introducing aggregate expressions and execution support:
* Add an aggregate expression family (count / count_null / count_star / max / min) with bound/unbound types, visitor and binder support.
* Add `AggregateEvaluator` for count/max/min execution over `StructLike` rows.
* Expose aggregate factories in `Expressions` and wire them into the CMake/Meson builds with new aggregate tests.
1 parent dbc9c1c commit dbcbdf2

File tree

16 files changed

+1007
-5
lines changed

16 files changed

+1007
-5
lines changed

src/iceberg/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ set(ICEBERG_INCLUDES "$<BUILD_INTERFACE:${PROJECT_BINARY_DIR}/src>"
2020
set(ICEBERG_SOURCES
2121
arrow_c_data_guard_internal.cc
2222
catalog/memory/in_memory_catalog.cc
23+
expression/aggregate.cc
2324
expression/binder.cc
2425
expression/evaluator.cc
2526
expression/expression.cc
Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include "iceberg/expression/aggregate.h"
21+
22+
#include <format>
23+
#include <optional>
24+
#include <vector>
25+
26+
#include "iceberg/expression/literal.h"
27+
#include "iceberg/row/struct_like.h"
28+
#include "iceberg/type.h"
29+
#include "iceberg/util/checked_cast.h"
30+
#include "iceberg/util/macros.h"
31+
32+
namespace iceberg {
33+
34+
namespace {
35+
36+
std::shared_ptr<PrimitiveType> GetPrimitiveType(const BoundTerm& term) {
37+
ICEBERG_DCHECK(term.type()->is_primitive(), "Value aggregate term should be primitive");
38+
return internal::checked_pointer_cast<PrimitiveType>(term.type());
39+
}
40+
41+
class CountAggregator : public BoundAggregate::Aggregator {
42+
public:
43+
explicit CountAggregator(const CountAggregate& aggregate) : aggregate_(aggregate) {}
44+
45+
Status Update(const StructLike& row) override {
46+
ICEBERG_ASSIGN_OR_RAISE(auto count, aggregate_.CountFor(row));
47+
count_ += count;
48+
return {};
49+
}
50+
51+
Literal GetResult() const override { return Literal::Long(count_); }
52+
53+
private:
54+
const CountAggregate& aggregate_;
55+
int64_t count_ = 0;
56+
};
57+
58+
class MaxAggregator : public BoundAggregate::Aggregator {
59+
public:
60+
explicit MaxAggregator(const MaxAggregate& aggregate)
61+
: aggregate_(aggregate),
62+
current_(Literal::Null(GetPrimitiveType(*aggregate_.term()))) {}
63+
64+
Status Update(const StructLike& data) override {
65+
ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(data));
66+
if (value.IsNull()) {
67+
return {};
68+
}
69+
if (current_.IsNull()) {
70+
current_ = std::move(value);
71+
return {};
72+
}
73+
74+
if (auto ordering = value <=> current_;
75+
ordering == std::partial_ordering::unordered) {
76+
return InvalidArgument("Cannot compare literal {} with current value {}",
77+
value.ToString(), current_.ToString());
78+
} else if (ordering == std::partial_ordering::greater) {
79+
current_ = std::move(value);
80+
}
81+
82+
return {};
83+
}
84+
85+
Literal GetResult() const override { return current_; }
86+
87+
private:
88+
const MaxAggregate& aggregate_;
89+
Literal current_;
90+
};
91+
92+
class MinAggregator : public BoundAggregate::Aggregator {
93+
public:
94+
explicit MinAggregator(const MinAggregate& aggregate)
95+
: aggregate_(aggregate),
96+
current_(Literal::Null(GetPrimitiveType(*aggregate_.term()))) {}
97+
98+
Status Update(const StructLike& data) override {
99+
ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(data));
100+
if (value.IsNull()) {
101+
return {};
102+
}
103+
if (current_.IsNull()) {
104+
current_ = std::move(value);
105+
return {};
106+
}
107+
108+
if (auto ordering = value <=> current_;
109+
ordering == std::partial_ordering::unordered) {
110+
return InvalidArgument("Cannot compare literal {} with current value {}",
111+
value.ToString(), current_.ToString());
112+
} else if (ordering == std::partial_ordering::less) {
113+
current_ = std::move(value);
114+
}
115+
return {};
116+
}
117+
118+
Literal GetResult() const override { return current_; }
119+
120+
private:
121+
const MinAggregate& aggregate_;
122+
Literal current_;
123+
};
124+
125+
} // namespace
126+
127+
/// \brief Render the aggregate as text, e.g. "count(*)" or "max(<term>)".
///
/// The term is only dereferenced for operations that take one; COUNT(*) and
/// the fallback for unexpected operations never touch it.
template <TermType T>
std::string Aggregate<T>::ToString() const {
  ICEBERG_DCHECK(IsSupportedOp(op()), "Unexpected aggregate operation");
  ICEBERG_DCHECK(op() == Expression::Operation::kCountStar || term() != nullptr,
                 "Aggregate term should not be null except for COUNT(*)");

  const auto operation = op();
  if (operation == Expression::Operation::kCountStar) {
    return "count(*)";
  }
  if (operation == Expression::Operation::kCount) {
    return std::format("count({})", term()->ToString());
  }
  if (operation == Expression::Operation::kCountNull) {
    return std::format("count_if({} is null)", term()->ToString());
  }
  if (operation == Expression::Operation::kMax) {
    return std::format("max({})", term()->ToString());
  }
  if (operation == Expression::Operation::kMin) {
    return std::format("min({})", term()->ToString());
  }
  return std::format("Invalid aggregate: {}", ::iceberg::ToString(operation));
}
148+
149+
// -------------------- CountAggregate --------------------
150+
151+
/// \brief Evaluate the per-row count for `data` and wrap it in a long literal.
/// \return The count as Literal::Long, or the error reported by CountFor.
Result<Literal> CountAggregate::Evaluate(const StructLike& data) const {
  ICEBERG_ASSIGN_OR_RAISE(auto row_count, CountFor(data));
  return Literal::Long(row_count);
}
154+
155+
std::unique_ptr<BoundAggregate::Aggregator> CountAggregate::NewAggregator() const {
156+
return std::unique_ptr<BoundAggregate::Aggregator>(new CountAggregator(*this));
157+
}
158+
159+
/// \brief Construct a COUNT(term) aggregate over `term`, tagged with Operation::kCount.
CountNonNullAggregate::CountNonNullAggregate(std::shared_ptr<BoundTerm> term)
    : CountAggregate(Expression::Operation::kCount, std::move(term)) {}
161+
162+
/// \brief Factory for a bound COUNT(term) aggregate.
/// \param term Bound term to count; must not be null.
/// \return The new aggregate, or InvalidExpression when `term` is null.
Result<std::unique_ptr<CountNonNullAggregate>> CountNonNullAggregate::Make(
    std::shared_ptr<BoundTerm> term) {
  if (term) {
    return std::unique_ptr<CountNonNullAggregate>(
        new CountNonNullAggregate(std::move(term)));
  }
  return InvalidExpression("Bound count aggregate requires non-null term");
}
170+
171+
/// \brief Per-row contribution: 1 when the term evaluates to a non-null value, else 0.
/// \return The contribution, or the error raised while evaluating the term.
Result<int64_t> CountNonNullAggregate::CountFor(const StructLike& data) const {
  ICEBERG_ASSIGN_OR_RAISE(auto value, term()->Evaluate(data));
  if (value.IsNull()) {
    return int64_t{0};
  }
  return int64_t{1};
}
175+
176+
/// \brief Construct a null-counting aggregate over `term`, tagged with Operation::kCountNull.
CountNullAggregate::CountNullAggregate(std::shared_ptr<BoundTerm> term)
    : CountAggregate(Expression::Operation::kCountNull, std::move(term)) {}
178+
179+
/// \brief Factory for a bound null-counting aggregate.
/// \param term Bound term whose nulls are counted; must not be null.
/// \return The new aggregate, or InvalidExpression when `term` is null.
Result<std::unique_ptr<CountNullAggregate>> CountNullAggregate::Make(
    std::shared_ptr<BoundTerm> term) {
  if (term) {
    return std::unique_ptr<CountNullAggregate>(new CountNullAggregate(std::move(term)));
  }
  return InvalidExpression("Bound count aggregate requires non-null term");
}
186+
187+
/// \brief Per-row contribution: 1 when the term evaluates to null, else 0.
/// \return The contribution, or the error raised while evaluating the term.
Result<int64_t> CountNullAggregate::CountFor(const StructLike& data) const {
  ICEBERG_ASSIGN_OR_RAISE(auto value, term()->Evaluate(data));
  if (value.IsNull()) {
    return int64_t{1};
  }
  return int64_t{0};
}
191+
192+
/// \brief Construct COUNT(*): tagged with Operation::kCountStar and carrying no term.
CountStarAggregate::CountStarAggregate()
    : CountAggregate(Expression::Operation::kCountStar, nullptr) {}
194+
195+
/// \brief Factory for a COUNT(*) aggregate; takes no term.
Result<std::unique_ptr<CountStarAggregate>> CountStarAggregate::Make() {
  std::unique_ptr<CountStarAggregate> count_star(new CountStarAggregate());
  return count_star;
}
198+
199+
/// \brief Per-row contribution for COUNT(*): always 1; the row is not inspected.
Result<int64_t> CountStarAggregate::CountFor(const StructLike& /*data*/) const {
  return int64_t{1};
}
202+
203+
/// \brief Construct a MAX aggregate over `term`, tagged with Operation::kMax.
MaxAggregate::MaxAggregate(std::shared_ptr<BoundTerm> term)
    : BoundAggregate(Expression::Operation::kMax, std::move(term)) {}
205+
206+
std::shared_ptr<MaxAggregate> MaxAggregate::Make(std::shared_ptr<BoundTerm> term) {
207+
return std::shared_ptr<MaxAggregate>(new MaxAggregate(std::move(term)));
208+
}
209+
210+
/// \brief Evaluate the aggregated term against a single row.
/// \return The term's value for `data`, or the error raised by the term.
Result<Literal> MaxAggregate::Evaluate(const StructLike& data) const {
  auto row_value = term()->Evaluate(data);
  return row_value;
}
213+
214+
std::unique_ptr<BoundAggregate::Aggregator> MaxAggregate::NewAggregator() const {
215+
return std::unique_ptr<BoundAggregate::Aggregator>(new MaxAggregator(*this));
216+
}
217+
218+
/// \brief Construct a MIN aggregate over `term`, tagged with Operation::kMin.
MinAggregate::MinAggregate(std::shared_ptr<BoundTerm> term)
    : BoundAggregate(Expression::Operation::kMin, std::move(term)) {}
220+
221+
std::shared_ptr<MinAggregate> MinAggregate::Make(std::shared_ptr<BoundTerm> term) {
222+
return std::shared_ptr<MinAggregate>(new MinAggregate(std::move(term)));
223+
}
224+
225+
/// \brief Evaluate the aggregated term against a single row.
/// \return The term's value for `data`, or the error raised by the term.
Result<Literal> MinAggregate::Evaluate(const StructLike& data) const {
  auto row_value = term()->Evaluate(data);
  return row_value;
}
228+
229+
std::unique_ptr<BoundAggregate::Aggregator> MinAggregate::NewAggregator() const {
230+
return std::unique_ptr<BoundAggregate::Aggregator>(new MinAggregator(*this));
231+
}
232+
233+
// -------------------- Unbound binding --------------------
234+
235+
template <typename B>
236+
Result<std::shared_ptr<Expression>> UnboundAggregateImpl<B>::Bind(
237+
const Schema& schema, bool case_sensitive) const {
238+
ICEBERG_DCHECK(UnboundAggregateImpl<B>::IsSupportedOp(this->op()),
239+
"Unexpected aggregate operation");
240+
241+
std::shared_ptr<B> bound_term;
242+
if (this->term()) {
243+
ICEBERG_ASSIGN_OR_RAISE(bound_term, this->term()->Bind(schema, case_sensitive));
244+
}
245+
246+
switch (this->op()) {
247+
case Expression::Operation::kCountStar:
248+
return CountStarAggregate::Make();
249+
case Expression::Operation::kCount:
250+
return CountNonNullAggregate::Make(std::move(bound_term));
251+
case Expression::Operation::kCountNull:
252+
return CountNullAggregate::Make(std::move(bound_term));
253+
case Expression::Operation::kMax:
254+
return MaxAggregate::Make(std::move(bound_term));
255+
case Expression::Operation::kMin:
256+
return MinAggregate::Make(std::move(bound_term));
257+
default:
258+
return NotSupported("Unsupported aggregate operation: {}",
259+
::iceberg::ToString(this->op()));
260+
}
261+
}
262+
263+
template <typename B>
264+
Result<std::shared_ptr<UnboundAggregateImpl<B>>> UnboundAggregateImpl<B>::Make(
265+
Expression::Operation op, std::shared_ptr<UnboundTerm<B>> term) {
266+
if (!Aggregate<UnboundTerm<B>>::IsSupportedOp(op)) {
267+
return NotSupported("Unsupported aggregate operation: {}", ::iceberg::ToString(op));
268+
}
269+
if (op != Expression::Operation::kCountStar && !term) {
270+
return InvalidExpression("Aggregate term cannot be null unless COUNT(*)");
271+
}
272+
273+
return std::shared_ptr<UnboundAggregateImpl<B>>(
274+
new UnboundAggregateImpl<B>(op, std::move(term)));
275+
}
276+
277+
template class Aggregate<UnboundTerm<BoundReference>>;
278+
template class Aggregate<BoundTerm>;
279+
template class UnboundAggregateImpl<BoundReference>;
280+
281+
// -------------------- AggregateEvaluator --------------------
282+
283+
namespace {
284+
285+
class AggregateEvaluatorImpl : public AggregateEvaluator {
286+
public:
287+
AggregateEvaluatorImpl(
288+
std::vector<std::shared_ptr<BoundAggregate>> aggregates,
289+
std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators)
290+
: aggregates_(std::move(aggregates)), aggregators_(std::move(aggregators)) {}
291+
292+
Status Update(const StructLike& data) override {
293+
for (auto& aggregator : aggregators_) {
294+
ICEBERG_RETURN_UNEXPECTED(aggregator->Update(data));
295+
}
296+
return {};
297+
}
298+
299+
Result<std::span<const Literal>> GetResults() const override {
300+
results_.clear();
301+
results_.reserve(aggregates_.size());
302+
for (const auto& aggregator : aggregators_) {
303+
results_.emplace_back(aggregator->GetResult());
304+
}
305+
return std::span<const Literal>(results_);
306+
}
307+
308+
Result<Literal> GetResult() const override {
309+
if (aggregates_.size() != 1) {
310+
return InvalidArgument(
311+
"GetResult() is only valid when evaluating a single aggregate");
312+
}
313+
314+
ICEBERG_ASSIGN_OR_RAISE(auto all, GetResults());
315+
return all.front();
316+
}
317+
318+
private:
319+
std::vector<std::shared_ptr<BoundAggregate>> aggregates_;
320+
std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators_;
321+
mutable std::vector<Literal> results_;
322+
};
323+
324+
} // namespace
325+
326+
/// \brief Create an evaluator for a single bound aggregate.
/// \param aggregate The aggregate to evaluate; must not be null.
/// \return The evaluator, or InvalidArgument when `aggregate` is null (a null
///         aggregate would otherwise be dereferenced while creating its
///         aggregator).
Result<std::unique_ptr<AggregateEvaluator>> AggregateEvaluator::Make(
    std::shared_ptr<BoundAggregate> aggregate) {
  if (aggregate == nullptr) {
    return InvalidArgument("AggregateEvaluator requires a non-null aggregate");
  }
  std::vector<std::shared_ptr<BoundAggregate>> aggs;
  aggs.push_back(std::move(aggregate));
  return Make(std::move(aggs));
}
332+
333+
/// \brief Create an evaluator for a list of bound aggregates.
///
/// One aggregator is created per aggregate, in the order given.
/// \param aggregates The aggregates to evaluate; must be non-empty and must
///        not contain null entries.
/// \return The evaluator, or InvalidArgument when the list is empty or
///         contains a null aggregate.
Result<std::unique_ptr<AggregateEvaluator>> AggregateEvaluator::Make(
    std::vector<std::shared_ptr<BoundAggregate>> aggregates) {
  if (aggregates.empty()) {
    return InvalidArgument("AggregateEvaluator requires at least one aggregate");
  }

  std::vector<std::unique_ptr<BoundAggregate::Aggregator>> aggregators;
  aggregators.reserve(aggregates.size());
  for (const auto& agg : aggregates) {
    // Reject null entries up front instead of dereferencing them below.
    if (agg == nullptr) {
      return InvalidArgument("AggregateEvaluator does not accept null aggregates");
    }
    aggregators.push_back(agg->NewAggregator());
  }

  return std::unique_ptr<AggregateEvaluator>(
      new AggregateEvaluatorImpl(std::move(aggregates), std::move(aggregators)));
}
347+
348+
} // namespace iceberg

0 commit comments

Comments
 (0)