Skip to content

Commit 91e3af6

Browse files
author
Patryk Dudziński
committed
[C++][sequencer][tests] fixed tests to use new Ordering requiremnt. In GroupByNode and ScalarAggregateNode moved inputs from NodeArgs to Make fucniton to establish ordering
1 parent 9a5d024 commit 91e3af6

File tree

6 files changed

+97
-89
lines changed

6 files changed

+97
-89
lines changed

cpp/src/arrow/acero/aggregate_internal.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,12 @@ Result<std::shared_ptr<Schema>> MakeOutputSchema(
214214
ARROW_ASSIGN_OR_RAISE(auto args,
215215
ScalarAggregateNode::MakeAggregateNodeArgs(
216216
input_schema, keys, segment_keys, aggregates, exec_ctx,
217-
/*concurrency=*/1, {}));
217+
/*concurrency=*/1));
218218
return std::move(args.output_schema);
219219
} else {
220220
ARROW_ASSIGN_OR_RAISE(
221221
auto args, GroupByNode::MakeAggregateNodeArgs(input_schema, keys, segment_keys,
222-
aggregates, exec_ctx, {}));
222+
aggregates, exec_ctx));
223223
return std::move(args.output_schema);
224224
}
225225
}

cpp/src/arrow/acero/aggregate_internal.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ struct AggregateNodeArgs {
9898
std::vector<const KernelType*> kernels;
9999
std::vector<std::vector<TypeHolder>> kernel_intypes;
100100
std::vector<std::vector<std::unique_ptr<KernelState>>> states;
101-
Ordering ordering;
101+
bool requires_ordering;
102102
};
103103

104104
std::vector<TypeHolder> ExtendWithGroupIdType(const std::vector<TypeHolder>& in_types);
@@ -194,7 +194,7 @@ class ScalarAggregateNode : public ExecNode,
194194
static Result<AggregateNodeArgs<ScalarAggregateKernel>> MakeAggregateNodeArgs(
195195
const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>& keys,
196196
const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& aggs,
197-
ExecContext* exec_ctx, size_t concurrency, std::vector<ExecNode*> inputs);
197+
ExecContext* exec_ctx, size_t concurrency);
198198

199199
static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
200200
const ExecNodeOptions& options);
@@ -291,7 +291,7 @@ class GroupByNode : public ExecNode,
291291
static Result<AggregateNodeArgs<HashAggregateKernel>> MakeAggregateNodeArgs(
292292
const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>& keys,
293293
const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& aggs,
294-
ExecContext* ctx, std::vector<ExecNode*> inputs);
294+
ExecContext* ctx);
295295

296296
static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
297297
const ExecNodeOptions& options);

cpp/src/arrow/acero/groupby_aggregate_node.cc

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ Status GroupByNode::Init() {
7474
Result<AggregateNodeArgs<HashAggregateKernel>> GroupByNode::MakeAggregateNodeArgs(
7575
const std::shared_ptr<Schema>& input_schema, const std::vector<FieldRef>& keys,
7676
const std::vector<FieldRef>& segment_keys, const std::vector<Aggregate>& aggs,
77-
ExecContext* ctx, std::vector<ExecNode*> inputs) {
77+
ExecContext* ctx) {
7878
// Find input field indices for key fields
7979
std::vector<int> key_field_ids(keys.size());
8080
for (size_t i = 0; i < keys.size(); ++i) {
@@ -167,37 +167,6 @@ Result<AggregateNodeArgs<HashAggregateKernel>> GroupByNode::MakeAggregateNodeArg
167167
for (size_t i = 0; i < aggs.size(); ++i) {
168168
output_fields[base + i] = agg_result_fields[i]->WithName(aggs[i].name);
169169
}
170-
Ordering out_ordering = Ordering::Unordered();
171-
if (requires_ordering && inputs[0]->ordering().is_implicit()) {
172-
out_ordering = Ordering::Implicit();
173-
} else if (requires_ordering) {
174-
std::vector<compute::SortKey> out_sort_keys;
175-
std::unordered_set<int> segmented_key_field_id_set(segment_key_field_ids.begin(),
176-
segment_key_field_ids.end());
177-
// Propagate output sorting only by segmented keys excluding sorting by regualr keys
178-
// since this will break the segmentation.
179-
for (auto key : inputs[0]->ordering().sort_keys()) {
180-
ARROW_ASSIGN_OR_RAISE(auto match, key.target.FindOne(*input_schema));
181-
if (segmented_key_field_id_set.find(match[0]) != segmented_key_field_id_set.end()) {
182-
out_sort_keys.emplace_back(key);
183-
} else {
184-
break;
185-
}
186-
}
187-
if (out_sort_keys.size() > 0)
188-
out_ordering = Ordering(out_sort_keys);
189-
else
190-
out_ordering = Ordering::Implicit();
191-
}
192-
193-
if (!out_ordering.is_unordered()) {
194-
if (inputs[0]->ordering().is_unordered()) {
195-
return Status::Invalid(
196-
"Aggregate node's input has no meaningful ordering and so limit/offset will be "
197-
"non-deterministic. Please establish order in some way (e.g. by inserting an "
198-
"order_by node)");
199-
}
200-
}
201170

202171
return AggregateNodeArgs<HashAggregateKernel>{schema(std::move(output_fields)),
203172
std::move(key_field_ids),
@@ -208,7 +177,7 @@ Result<AggregateNodeArgs<HashAggregateKernel>> GroupByNode::MakeAggregateNodeArg
208177
std::move(agg_kernels),
209178
std::move(agg_src_types),
210179
/*states=*/{},
211-
std::move(out_ordering)};
180+
requires_ordering};
212181
}
213182

214183
Result<ExecNode*> GroupByNode::Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
@@ -224,14 +193,45 @@ Result<ExecNode*> GroupByNode::Make(ExecPlan* plan, std::vector<ExecNode*> input
224193

225194
const auto& input_schema = input->output_schema();
226195
auto exec_ctx = plan->query_context()->exec_context();
227-
ARROW_ASSIGN_OR_RAISE(auto args, MakeAggregateNodeArgs(input_schema, keys, segment_keys,
228-
aggs, exec_ctx, inputs));
196+
ARROW_ASSIGN_OR_RAISE(
197+
auto args, MakeAggregateNodeArgs(input_schema, keys, segment_keys, aggs, exec_ctx));
229198

199+
Ordering out_ordering = Ordering::Unordered();
200+
if (args.requires_ordering && input->ordering().is_implicit()) {
201+
out_ordering = Ordering::Implicit();
202+
} else if (args.requires_ordering) {
203+
std::vector<compute::SortKey> out_sort_keys;
204+
std::unordered_set<int> segmented_key_field_id_set(args.segment_key_field_ids.begin(),
205+
args.segment_key_field_ids.end());
206+
// Propagate output sorting only by segmented keys excluding sorting by regular keys
207+
// since this will break the segmentation.
208+
for (auto key : input->ordering().sort_keys()) {
209+
ARROW_ASSIGN_OR_RAISE(auto match, key.target.FindOne(*input_schema));
210+
if (segmented_key_field_id_set.find(match[0]) != segmented_key_field_id_set.end()) {
211+
out_sort_keys.emplace_back(key);
212+
} else {
213+
break;
214+
}
215+
}
216+
if (out_sort_keys.size() > 0) {
217+
out_ordering = Ordering(out_sort_keys);
218+
} else {
219+
out_ordering = Ordering::Implicit();
220+
}
221+
}
222+
if (!out_ordering.is_unordered()) {
223+
if (inputs[0]->ordering().is_unordered()) {
224+
return Status::Invalid(
225+
"Aggregate node's input has no meaningful ordering and so limit/offset will be "
226+
"non-deterministic. Please establish order in some way (e.g. by inserting an "
227+
"order_by node)");
228+
}
229+
}
230230
return input->plan()->EmplaceNode<GroupByNode>(
231231
input, std::move(args.output_schema), std::move(args.grouping_key_field_ids),
232232
std::move(args.segment_key_field_ids), std::move(args.segmenter),
233233
std::move(args.kernel_intypes), std::move(args.target_fieldsets),
234-
std::move(args.aggregates), std::move(args.kernels), std::move(args.ordering));
234+
std::move(args.aggregates), std::move(args.kernels), std::move(out_ordering));
235235
}
236236

237237
Status GroupByNode::ResetKernelStates() {

cpp/src/arrow/acero/hash_aggregate_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "arrow/compute/cast.h"
4040
#include "arrow/compute/exec.h"
4141
#include "arrow/compute/exec_internal.h"
42+
#include "arrow/compute/ordering.h"
4243
#include "arrow/compute/registry.h"
4344
#include "arrow/compute/row/grouper.h"
4445
#include "arrow/table.h"
@@ -315,7 +316,8 @@ Result<Datum> RunGroupBy(const BatchesWithSchema& input,
315316
Declaration::Sequence(
316317
{
317318
{"source",
318-
SourceNodeOptions{input.schema, input.gen(use_threads, /*slow=*/false)}},
319+
SourceNodeOptions{input.schema, input.gen(use_threads, /*slow=*/false),
320+
Ordering::Implicit()}},
319321
{"aggregate", AggregateNodeOptions{aggregates, std::move(keys),
320322
std::move(segment_keys)}},
321323
{"sink", SinkNodeOptions{&sink_gen}},

cpp/src/arrow/acero/plan_test.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,12 +1651,13 @@ TEST(ExecPlanExecution, SegmentedAggregationWithMultiThreading) {
16511651
data.schema = schema({field("i32", int32())});
16521652
Declaration plan = Declaration::Sequence(
16531653
{{"source",
1654-
SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}},
1654+
SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false),
1655+
Ordering::Unordered()}},
16551656
{"aggregate", AggregateNodeOptions{/*aggregates=*/{
16561657
{"count", nullptr, "i32", "count(i32)"},
16571658
},
16581659
/*keys=*/{}, /*segment_keys=*/{"i32"}}}});
1659-
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, HasSubstr("multi-threaded"),
1660+
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("meaningful ordering"),
16601661
DeclarationToExecBatches(std::move(plan)));
16611662
}
16621663

@@ -1674,7 +1675,8 @@ TEST(ExecPlanExecution, SegmentedAggregationWithOneSegment) {
16741675

16751676
Declaration plan = Declaration::Sequence(
16761677
{{"source",
1677-
SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}},
1678+
SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false),
1679+
Ordering::Implicit()}},
16781680
{"aggregate", AggregateNodeOptions{/*aggregates=*/{
16791681
{"hash_sum", nullptr, "c", "sum(c)"},
16801682
{"hash_mean", nullptr, "c", "mean(c)"},
@@ -1703,7 +1705,8 @@ TEST(ExecPlanExecution, SegmentedAggregationWithTwoSegments) {
17031705

17041706
Declaration plan = Declaration::Sequence(
17051707
{{"source",
1706-
SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}},
1708+
SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false),
1709+
Ordering::Implicit()}},
17071710
{"aggregate", AggregateNodeOptions{/*aggregates=*/{
17081711
{"hash_sum", nullptr, "c", "sum(c)"},
17091712
{"hash_mean", nullptr, "c", "mean(c)"},
@@ -1733,7 +1736,8 @@ TEST(ExecPlanExecution, SegmentedAggregationWithBatchCrossingSegment) {
17331736

17341737
Declaration plan = Declaration::Sequence(
17351738
{{"source",
1736-
SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}},
1739+
SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false),
1740+
Ordering::Implicit()}},
17371741
{"aggregate", AggregateNodeOptions{/*aggregates=*/{
17381742
{"hash_sum", nullptr, "c", "sum(c)"},
17391743
{"hash_mean", nullptr, "c", "mean(c)"},

cpp/src/arrow/acero/scalar_aggregate_node.cc

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ ScalarAggregateNode::MakeAggregateNodeArgs(const std::shared_ptr<Schema>& input_
6262
const std::vector<FieldRef>& keys,
6363
const std::vector<FieldRef>& segment_keys,
6464
const std::vector<Aggregate>& aggs,
65-
ExecContext* exec_ctx, size_t concurrency,
66-
std::vector<ExecNode*> inputs) {
65+
ExecContext* exec_ctx, size_t concurrency) {
6766
// Copy (need to modify options pointer below)
6867
std::vector<Aggregate> aggregates(aggs);
6968
std::vector<int> segment_field_ids(segment_keys.size());
@@ -158,27 +157,61 @@ ScalarAggregateNode::MakeAggregateNodeArgs(const std::shared_ptr<Schema>& input_
158157
fields[base + i] = field(aggregates[i].name, out_type.GetSharedPtr());
159158
}
160159

160+
return AggregateNodeArgs<ScalarAggregateKernel>{schema(std::move(fields)),
161+
/*grouping_key_field_ids=*/{},
162+
std::move(segment_field_ids),
163+
std::move(segmenter),
164+
std::move(target_fieldsets),
165+
std::move(aggregates),
166+
std::move(kernels),
167+
std::move(kernel_intypes),
168+
std::move(states),
169+
requires_ordering};
170+
}
171+
172+
Result<ExecNode*> ScalarAggregateNode::Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
173+
const ExecNodeOptions& options) {
174+
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "ScalarAggregateNode"));
175+
176+
const auto& aggregate_options = checked_cast<const AggregateNodeOptions&>(options);
177+
auto aggregates = aggregate_options.aggregates;
178+
const auto& keys = aggregate_options.keys;
179+
const auto& segment_keys = aggregate_options.segment_keys;
180+
const auto input = inputs[0];
181+
const auto concurrency = plan->query_context()->max_concurrency();
182+
183+
if (keys.size() > 0) {
184+
return Status::Invalid("Scalar aggregation with some key");
185+
}
186+
187+
const auto& input_schema = inputs[0]->output_schema();
188+
auto exec_ctx = plan->query_context()->exec_context();
189+
190+
ARROW_ASSIGN_OR_RAISE(
191+
auto args, MakeAggregateNodeArgs(input_schema, keys, segment_keys, aggregates,
192+
exec_ctx, concurrency));
161193
Ordering out_ordering = Ordering::Unordered();
162-
if (requires_ordering && inputs[0]->ordering().is_implicit()) {
194+
if (args.requires_ordering && input->ordering().is_implicit()) {
163195
out_ordering = Ordering::Implicit();
164-
} else if (requires_ordering) {
196+
} else if (args.requires_ordering) {
165197
std::vector<compute::SortKey> out_sort_keys;
166-
std::unordered_set<int> segmented_key_field_id_set(segment_field_ids.begin(),
167-
segment_field_ids.end());
168-
// Propagate output sorting only by segmented keys excluding sorting by regualr keys
198+
std::unordered_set<int> segmented_key_field_id_set(args.segment_key_field_ids.begin(),
199+
args.segment_key_field_ids.end());
200+
// Propagate output sorting only by segmented keys excluding sorting by regular keys
169201
// since this will break the segmentation.
170-
for (auto key : inputs[0]->ordering().sort_keys()) {
202+
for (auto key : input->ordering().sort_keys()) {
171203
ARROW_ASSIGN_OR_RAISE(auto match, key.target.FindOne(*input_schema));
172204
if (segmented_key_field_id_set.find(match[0]) != segmented_key_field_id_set.end()) {
173205
out_sort_keys.emplace_back(key);
174206
} else {
175207
break;
176208
}
177209
}
178-
if (out_sort_keys.size() > 0)
210+
if (out_sort_keys.size() > 0) {
179211
out_ordering = Ordering(out_sort_keys);
180-
else
212+
} else {
181213
out_ordering = Ordering::Implicit();
214+
}
182215
}
183216

184217
if (!out_ordering.is_unordered()) {
@@ -189,42 +222,11 @@ ScalarAggregateNode::MakeAggregateNodeArgs(const std::shared_ptr<Schema>& input_
189222
"order_by node)");
190223
}
191224
}
192-
193-
return AggregateNodeArgs<ScalarAggregateKernel>{
194-
schema(std::move(fields)),
195-
/*grouping_key_field_ids=*/{}, std::move(segment_field_ids),
196-
std::move(segmenter), std::move(target_fieldsets),
197-
std::move(aggregates), std::move(kernels),
198-
std::move(kernel_intypes), std::move(states),
199-
std::move(out_ordering)};
200-
}
201-
202-
Result<ExecNode*> ScalarAggregateNode::Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
203-
const ExecNodeOptions& options) {
204-
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "ScalarAggregateNode"));
205-
206-
const auto& aggregate_options = checked_cast<const AggregateNodeOptions&>(options);
207-
auto aggregates = aggregate_options.aggregates;
208-
const auto& keys = aggregate_options.keys;
209-
const auto& segment_keys = aggregate_options.segment_keys;
210-
const auto concurrency = plan->query_context()->max_concurrency();
211-
212-
if (keys.size() > 0) {
213-
return Status::Invalid("Scalar aggregation with some key");
214-
}
215-
216-
const auto& input_schema = inputs[0]->output_schema();
217-
auto exec_ctx = plan->query_context()->exec_context();
218-
219-
ARROW_ASSIGN_OR_RAISE(
220-
auto args, MakeAggregateNodeArgs(input_schema, keys, segment_keys, aggregates,
221-
exec_ctx, concurrency, inputs));
222-
223225
return plan->EmplaceNode<ScalarAggregateNode>(
224226
plan, std::move(inputs), std::move(args.output_schema), std::move(args.segmenter),
225227
std::move(args.segment_key_field_ids), std::move(args.target_fieldsets),
226228
std::move(args.aggregates), std::move(args.kernels), std::move(args.kernel_intypes),
227-
std::move(args.states), std::move(args.ordering));
229+
std::move(args.states), std::move(out_ordering));
228230
}
229231

230232
Status ScalarAggregateNode::DoConsume(const ExecSpan& batch, size_t thread_index) {

0 commit comments

Comments
 (0)