Skip to content

Commit 4d625b3

Browse files
authored
GH-45732: [C++][Compute] Accept more pivot key types (#45945)
### Rationale for this change

Allow the `pivot_wider` and `hash_pivot_wider` functions to accept an integral pivot key column, in addition to binary-like columns. Since the `key_names` option is a vector of strings, its elements are cast to the appropriate pivot key type for matching.

### Are these changes tested?

Yes, by new unit tests.

### Are there any user-facing changes?

No.

* GitHub Issue: #45732

Lead-authored-by: Antoine Pitrou <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
1 parent 686971e commit 4d625b3

File tree

9 files changed

+408
-248
lines changed

9 files changed

+408
-248
lines changed

cpp/src/arrow/acero/hash_aggregate_test.cc

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4440,7 +4440,7 @@ TEST_P(GroupBy, PivotBasics) {
44404440
}
44414441
}
44424442

4443-
TEST_P(GroupBy, PivotAllKeyTypes) {
4443+
TEST_P(GroupBy, PivotBinaryKeyTypes) {
44444444
auto value_type = float32();
44454445
std::vector<std::string> table_json = {R"([
44464446
[1, "width", 10.5],
@@ -4462,6 +4462,49 @@ TEST_P(GroupBy, PivotAllKeyTypes) {
44624462
ARROW_SCOPED_TRACE("key_type = ", *key_type);
44634463
TestPivot(key_type, value_type, options, table_json, expected_json);
44644464
}
4465+
4466+
auto key_type = fixed_size_binary(3);
4467+
table_json = {R"([
4468+
[1, "wid", 10.5],
4469+
[2, "wid", 11.5]
4470+
])",
4471+
R"([
4472+
[2, "hei", 12.5],
4473+
[3, "wid", 13.5],
4474+
[1, "hei", 14.5]
4475+
])"};
4476+
expected_json = R"([
4477+
[1, {"hei": 14.5, "wid": 10.5} ],
4478+
[2, {"hei": 12.5, "wid": 11.5} ],
4479+
[3, {"hei": null, "wid": 13.5} ]
4480+
])";
4481+
options.key_names = {"hei", "wid"};
4482+
ARROW_SCOPED_TRACE("key_type = ", *key_type);
4483+
TestPivot(key_type, value_type, options, table_json, expected_json);
4484+
}
4485+
4486+
TEST_P(GroupBy, PivotIntegerKeyTypes) {
4487+
auto value_type = float32();
4488+
std::vector<std::string> table_json = {R"([
4489+
[1, 78, 10.5],
4490+
[2, 78, 11.5]
4491+
])",
4492+
R"([
4493+
[2, 56, 12.5],
4494+
[3, 78, 13.5],
4495+
[1, 56, 14.5]
4496+
])"};
4497+
std::string expected_json = R"([
4498+
[1, {"56": 14.5, "78": 10.5} ],
4499+
[2, {"56": 12.5, "78": 11.5} ],
4500+
[3, {"56": null, "78": 13.5} ]
4501+
])";
4502+
PivotWiderOptions options(/*key_names=*/{"56", "78"});
4503+
4504+
for (const auto& key_type : IntTypes()) {
4505+
ARROW_SCOPED_TRACE("key_type = ", *key_type);
4506+
TestPivot(key_type, value_type, options, table_json, expected_json);
4507+
}
44654508
}
44664509

44674510
TEST_P(GroupBy, PivotNumericValues) {
@@ -4749,6 +4792,21 @@ TEST_P(GroupBy, PivotDuplicateKeys) {
47494792
RunPivot(key_type, value_type, options, table_json));
47504793
}
47514794

4795+
TEST_P(GroupBy, PivotInvalidKeys) {
4796+
// Integer key type, but key names cannot be converted to int
4797+
auto key_type = int32();
4798+
auto value_type = float32();
4799+
std::vector<std::string> table_json = {R"([])"};
4800+
PivotWiderOptions options(/*key_names=*/{"123", "width"});
4801+
EXPECT_RAISES_WITH_MESSAGE_THAT(
4802+
Invalid, HasSubstr("Failed to parse string: 'width' as a scalar of type int32"),
4803+
RunPivot(key_type, value_type, options, table_json));
4804+
options.key_names = {"12.3", "45"};
4805+
EXPECT_RAISES_WITH_MESSAGE_THAT(
4806+
Invalid, HasSubstr("Failed to parse string: '12.3' as a scalar of type int32"),
4807+
RunPivot(key_type, value_type, options, table_json));
4808+
}
4809+
47524810
TEST_P(GroupBy, PivotDuplicateValues) {
47534811
auto key_type = utf8();
47544812
auto value_type = float32();

cpp/src/arrow/compute/api_aggregate.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,10 @@ class ARROW_EXPORT TDigestOptions : public FunctionOptions {
202202
/// - The corresponding `Aggregate::target` must have two FieldRef elements;
203203
/// the first one points to the pivot key column, the second points to the
204204
/// pivoted data column.
205-
/// - The pivot key column must be string-like; its values will be matched
206-
/// against `key_names` in order to dispatch the pivoted data into the
207-
/// output.
205+
/// - The pivot key column can be string, binary or integer; its values will be
206+
/// matched against `key_names` in order to dispatch the pivoted data into
207+
/// the output. If the pivot key column is not string-like, the `key_names`
208+
/// will be cast to the pivot key type.
208209
///
209210
/// "pivot_wider" example
210211
/// ---------------------

cpp/src/arrow/compute/exec.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ struct ExecValue {
276276
ArraySpan array = {};
277277
const Scalar* scalar = NULLPTR;
278278

279-
ExecValue(Scalar* scalar) // NOLINT implicit conversion
279+
ExecValue(const Scalar* scalar) // NOLINT implicit conversion
280280
: scalar(scalar) {}
281281

282282
ExecValue(ArraySpan array) // NOLINT implicit conversion

cpp/src/arrow/compute/kernels/aggregate_pivot.cc

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "arrow/scalar.h"
2323
#include "arrow/util/bit_run_reader.h"
2424
#include "arrow/util/logging.h"
25+
#include "arrow/visit_data_inline.h"
2526

2627
namespace arrow::compute::internal {
2728
namespace {
@@ -30,7 +31,8 @@ using arrow::internal::VisitSetBitRunsVoid;
3031
using arrow::util::span;
3132

3233
struct PivotImpl : public ScalarAggregator {
33-
Status Init(const PivotWiderOptions& options, const std::vector<TypeHolder>& in_types) {
34+
Status Init(const PivotWiderOptions& options, const std::vector<TypeHolder>& in_types,
35+
ExecContext* ctx) {
3436
options_ = &options;
3537
key_type_ = in_types[0].GetSharedPtr();
3638
auto value_type = in_types[1].GetSharedPtr();
@@ -42,47 +44,57 @@ struct PivotImpl : public ScalarAggregator {
4244
values_.push_back(MakeNullScalar(value_type));
4345
}
4446
out_type_ = struct_(std::move(fields));
45-
ARROW_ASSIGN_OR_RAISE(key_mapper_, PivotWiderKeyMapper::Make(*key_type_, options_));
47+
ARROW_ASSIGN_OR_RAISE(key_mapper_,
48+
PivotWiderKeyMapper::Make(*key_type_, options_, ctx));
4649
return Status::OK();
4750
}
4851

4952
Status Consume(KernelContext*, const ExecSpan& batch) override {
5053
DCHECK_EQ(batch.num_values(), 2);
5154
if (batch[0].is_array()) {
52-
ARROW_ASSIGN_OR_RAISE(span<const PivotWiderKeyIndex> keys,
53-
key_mapper_->MapKeys(batch[0].array));
55+
ARROW_ASSIGN_OR_RAISE(auto keys_array, key_mapper_->MapKeys(batch[0].array));
56+
DCHECK_EQ(keys_array->type->id(), Type::UINT32);
57+
ArraySpan keys_span(*keys_array);
5458
if (batch[1].is_array()) {
5559
// Array keys, array values
5660
auto values = batch[1].array.ToArray();
57-
for (int64_t i = 0; i < batch.length; ++i) {
58-
PivotWiderKeyIndex key = keys[i];
59-
if (key != kNullPivotKey && !values->IsNull(i)) {
60-
if (ARROW_PREDICT_FALSE(values_[key]->is_valid)) {
61-
return DuplicateValue();
62-
}
63-
ARROW_ASSIGN_OR_RAISE(values_[key], values->GetScalar(i));
64-
DCHECK(values_[key]->is_valid);
65-
}
66-
}
61+
int64_t i = 0;
62+
RETURN_NOT_OK(VisitArraySpanInline<UInt32Type>(
63+
keys_span,
64+
[&](uint32_t key) {
65+
if (!values->IsNull(i)) {
66+
if (ARROW_PREDICT_FALSE(values_[key]->is_valid)) {
67+
return DuplicateValue();
68+
}
69+
ARROW_ASSIGN_OR_RAISE(values_[key], values->GetScalar(i));
70+
}
71+
++i;
72+
return Status::OK();
73+
},
74+
[&]() {
75+
++i;
76+
return Status::OK();
77+
}));
6778
} else {
6879
// Array keys, scalar value
6980
const Scalar* value = batch[1].scalar;
7081
if (value->is_valid) {
71-
for (int64_t i = 0; i < batch.length; ++i) {
72-
PivotWiderKeyIndex key = keys[i];
73-
if (key != kNullPivotKey) {
74-
if (ARROW_PREDICT_FALSE(values_[key]->is_valid)) {
75-
return DuplicateValue();
76-
}
77-
values_[key] = value->GetSharedPtr();
78-
}
79-
}
82+
RETURN_NOT_OK(VisitArraySpanInline<UInt32Type>(
83+
keys_span,
84+
[&](uint32_t key) {
85+
if (ARROW_PREDICT_FALSE(values_[key]->is_valid)) {
86+
return DuplicateValue();
87+
}
88+
values_[key] = value->GetSharedPtr();
89+
return Status::OK();
90+
},
91+
[] { return Status::OK(); }));
8092
}
8193
}
8294
} else {
83-
ARROW_ASSIGN_OR_RAISE(PivotWiderKeyIndex key,
84-
key_mapper_->MapKey(*batch[0].scalar));
85-
if (key != kNullPivotKey) {
95+
ARROW_ASSIGN_OR_RAISE(auto maybe_key, key_mapper_->MapKey(*batch[0].scalar));
96+
if (maybe_key.has_value()) {
97+
PivotWiderKeyIndex key = maybe_key.value();
8698
if (batch[1].is_array()) {
8799
// Scalar key, array values
88100
auto values = batch[1].array.ToArray();
@@ -145,10 +157,8 @@ struct PivotImpl : public ScalarAggregator {
145157
Result<std::unique_ptr<KernelState>> PivotInit(KernelContext* ctx,
146158
const KernelInitArgs& args) {
147159
const auto& options = checked_cast<const PivotWiderOptions&>(*args.options);
148-
DCHECK_EQ(args.inputs.size(), 2);
149-
DCHECK(is_base_binary_like(args.inputs[0].id()));
150160
auto state = std::make_unique<PivotImpl>();
151-
RETURN_NOT_OK(state->Init(options, args.inputs));
161+
RETURN_NOT_OK(state->Init(options, args.inputs, ctx->exec_context()));
152162
// GH-45718: This can be simplified once we drop the R openSUSE155 crossbow
153163
// job
154164
// R build with openSUSE155 requires an explicit shared_ptr construction
@@ -167,6 +177,8 @@ const FunctionDoc pivot_doc{
167177
"is emitted. If a pivot key doesn't appear, null is emitted.\n"
168178
"If more than one non-null value is encountered for a given pivot key,\n"
169179
"Invalid is raised.\n"
180+
"The pivot key column can be string, binary or integer. The `key_names`\n"
181+
"will be cast to the pivot key column type for matching.\n"
170182
"Behavior of unexpected pivot keys is controlled by `unexpected_key_behavior`\n"
171183
"in PivotWiderOptions."),
172184
{"pivot_keys", "pivot_values"},
@@ -179,12 +191,19 @@ void RegisterScalarAggregatePivot(FunctionRegistry* registry) {
179191

180192
auto func = std::make_shared<ScalarAggregateFunction>(
181193
"pivot_wider", Arity::Binary(), pivot_doc, &default_pivot_options);
182-
183-
for (auto key_type : BaseBinaryTypes()) {
184-
auto sig = KernelSignature::Make({key_type->id(), InputType::Any()},
194+
auto add_kernel = [&](InputType key_type) {
195+
auto sig = KernelSignature::Make({key_type, InputType::Any()},
185196
OutputType(ResolveOutputType));
186197
AddAggKernel(std::move(sig), PivotInit, func.get());
198+
};
199+
200+
for (const auto& key_type : BaseBinaryTypes()) {
201+
add_kernel(key_type->id());
202+
}
203+
for (const auto& key_type : IntTypes()) {
204+
add_kernel(key_type->id());
187205
}
206+
add_kernel(Type::FIXED_SIZE_BINARY);
188207
DCHECK_OK(registry->AddFunction(std::move(func)));
189208
}
190209

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4504,10 +4504,9 @@ TEST_F(TestPivotKernel, Basics) {
45044504
PivotWiderOptions(/*key_names=*/{"height", "width"}));
45054505
}
45064506

4507-
TEST_F(TestPivotKernel, AllKeyTypes) {
4507+
TEST_F(TestPivotKernel, BinaryKeyTypes) {
4508+
auto value_type = float32();
45084509
for (auto key_type : BaseBinaryTypes()) {
4509-
auto value_type = float32();
4510-
45114510
auto keys = ArrayFromJSON(key_type, R"(["width", "height"])");
45124511
auto values = ArrayFromJSON(value_type, "[10.5, 11.5]");
45134512
auto expected =
@@ -4516,6 +4515,25 @@ TEST_F(TestPivotKernel, AllKeyTypes) {
45164515
AssertPivot(keys, values, *expected,
45174516
PivotWiderOptions(/*key_names=*/{"height", "width"}));
45184517
}
4518+
auto key_type = fixed_size_binary(3);
4519+
auto keys = ArrayFromJSON(key_type, R"(["wid", "hei"])");
4520+
auto values = ArrayFromJSON(value_type, "[10.5, 11.5]");
4521+
auto expected = ScalarFromJSON(
4522+
struct_({field("hei", value_type), field("wid", value_type)}), "[11.5, 10.5]");
4523+
AssertPivot(keys, values, *expected, PivotWiderOptions(/*key_names=*/{"hei", "wid"}));
4524+
}
4525+
4526+
TEST_F(TestPivotKernel, IntegerKeyTypes) {
4527+
// It is possible to use an integer key column, while passing its string equivalent
4528+
// in PivotWiderOptions::key_names.
4529+
auto value_type = float32();
4530+
for (auto key_type : IntTypes()) {
4531+
auto keys = ArrayFromJSON(key_type, "[34, 12]");
4532+
auto values = ArrayFromJSON(value_type, "[10.5, 11.5]");
4533+
auto expected = ScalarFromJSON(
4534+
struct_({field("12", value_type), field("34", value_type)}), "[11.5, 10.5]");
4535+
AssertPivot(keys, values, *expected, PivotWiderOptions(/*key_names=*/{"12", "34"}));
4536+
}
45194537
}
45204538

45214539
TEST_F(TestPivotKernel, Numbers) {
@@ -4724,6 +4742,24 @@ TEST_F(TestPivotKernel, DuplicateKeyNames) {
47244742
CallFunction("pivot_wider", {keys, values}, &options));
47254743
}
47264744

4745+
TEST_F(TestPivotKernel, InvalidKeyName) {
4746+
auto key_type = int32();
4747+
auto value_type = float32();
4748+
4749+
auto keys = ArrayFromJSON(key_type, "[]");
4750+
auto values = ArrayFromJSON(value_type, "[]");
4751+
auto options = PivotWiderOptions(/*key_names=*/{"123", "width"});
4752+
EXPECT_RAISES_WITH_MESSAGE_THAT(
4753+
Invalid,
4754+
::testing::HasSubstr("Failed to parse string: 'width' as a scalar of type int32"),
4755+
CallFunction("pivot_wider", {keys, values}, &options));
4756+
options.key_names = {"12.3", "45"};
4757+
EXPECT_RAISES_WITH_MESSAGE_THAT(
4758+
Invalid,
4759+
::testing::HasSubstr("Failed to parse string: '12.3' as a scalar of type int32"),
4760+
CallFunction("pivot_wider", {keys, values}, &options));
4761+
}
4762+
47274763
TEST_F(TestPivotKernel, DuplicateValues) {
47284764
auto key_type = utf8();
47294765
auto value_type = float32();

0 commit comments

Comments
 (0)