Skip to content

Commit be844b7

Browse files
committed
[C++][Python] Support Python-like slicing in list_slice kernel
1 parent 3cbc27a commit be844b7

File tree

3 files changed

+58
-19
lines changed

3 files changed

+58
-19
lines changed

cpp/src/arrow/compute/kernels/scalar_nested.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,9 @@ struct ListSlice {
183183
const auto* list_type = checked_cast<const BaseListType*>(list_array.type);
184184

185185
// Pre-conditions
186-
if (opts.start < 0 || (opts.stop.has_value() && opts.start >= opts.stop.value())) {
187-
// TODO(ARROW-18281): support start == stop which should give empty lists
186+
if (opts.start < 0 || (opts.stop.has_value() && opts.start > opts.stop.value())) {
188187
return Status::Invalid("`start`(", opts.start,
189-
") should be greater than 0 and smaller than `stop`(",
188+
") should be >= 0 and not greater than `stop`(",
190189
ToString(opts.stop), ")");
191190
}
192191
if (opts.step < 1) {

cpp/src/arrow/compute/kernels/scalar_nested_test.cc

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,32 @@ TEST(TestScalarNested, ListSliceOutputEqualsInputType) {
306306
}
307307
}
308308

309+
TEST(TestScalarNested, ListSliceEmptyLists) {
310+
// start == stop should return empty lists
311+
auto input = ArrayFromJSON(list(int32()), "[[1, 2, 3], [4, 5], null]");
312+
ListSliceOptions args(/*start=*/0, /*stop=*/0, /*step=*/1);
313+
auto expected = ArrayFromJSON(list(int32()), "[[], [], null]");
314+
CheckScalarUnary("list_slice", input, expected, &args);
315+
316+
// Different start position
317+
args.start = 1;
318+
args.stop = 1;
319+
CheckScalarUnary("list_slice", input, expected, &args);
320+
321+
// Large list
322+
auto input_large = ArrayFromJSON(large_list(int32()), "[[1, 2, 3], [4, 5]]");
323+
args.start = 0;
324+
args.stop = 0;
325+
auto expected_large = ArrayFromJSON(large_list(int32()), "[[], []]");
326+
CheckScalarUnary("list_slice", input_large, expected_large, &args);
327+
328+
// Fixed size list -> fixed size list[0]
329+
auto input_fixed = ArrayFromJSON(fixed_size_list(int32(), 3), "[[1, 2, 3], [4, 5, 6]]");
330+
args.return_fixed_size_list = true;
331+
auto expected_fixed = ArrayFromJSON(fixed_size_list(int32(), 0), "[[], []]");
332+
CheckScalarUnary("list_slice", input_fixed, expected_fixed, &args);
333+
}
334+
309335
TEST(TestScalarNested, ListSliceBadParameters) {
310336
auto input = ArrayFromJSON(list(int32()), "[[1]]");
311337

@@ -314,23 +340,14 @@ TEST(TestScalarNested, ListSliceBadParameters) {
314340
/*return_fixed_size_list=*/true);
315341
EXPECT_RAISES_WITH_MESSAGE_THAT(
316342
Invalid,
317-
::testing::HasSubstr(
318-
"`start`(-1) should be greater than 0 and smaller than `stop`(1)"),
343+
::testing::HasSubstr("`start`(-1) should be >= 0 and not greater than `stop`(1)"),
319344
CallFunction("list_slice", {input}, &args));
320345
// start greater than stop
321346
args.start = 1;
322347
args.stop = 0;
323348
EXPECT_RAISES_WITH_MESSAGE_THAT(
324349
Invalid,
325-
::testing::HasSubstr(
326-
"`start`(1) should be greater than 0 and smaller than `stop`(0)"),
327-
CallFunction("list_slice", {input}, &args));
328-
// start same as stop
329-
args.stop = args.start;
330-
EXPECT_RAISES_WITH_MESSAGE_THAT(
331-
Invalid,
332-
::testing::HasSubstr(
333-
"`start`(1) should be greater than 0 and smaller than `stop`(1)"),
350+
::testing::HasSubstr("`start`(1) should be >= 0 and not greater than `stop`(0)"),
334351
CallFunction("list_slice", {input}, &args));
335352
// stop not set and FixedSizeList requested with variable sized input
336353
args.stop = std::nullopt;

python/pyarrow/tests/test_compute.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3946,16 +3946,12 @@ def test_list_slice_field_names_retained(return_fixed_size, type):
39463946

39473947
def test_list_slice_bad_parameters():
39483948
arr = pa.array([[1]], pa.list_(pa.int8(), 1))
3949-
msg = r"`start`(.*) should be greater than 0 and smaller than `stop`(.*)"
3949+
msg = r"`start`(.*) should be >= 0 and not greater than `stop`(.*)"
39503950
with pytest.raises(pa.ArrowInvalid, match=msg):
39513951
pc.list_slice(arr, -1, 1) # negative start?
39523952
with pytest.raises(pa.ArrowInvalid, match=msg):
39533953
pc.list_slice(arr, 2, 1) # start > stop?
39543954

3955-
# TODO(ARROW-18281): start==stop -> empty lists
3956-
with pytest.raises(pa.ArrowInvalid, match=msg):
3957-
pc.list_slice(arr, 0, 0) # start == stop?
3958-
39593955
# Step not >= 1
39603956
msg = "`step` must be >= 1, got: "
39613957
with pytest.raises(pa.ArrowInvalid, match=msg + "0"):
@@ -3964,6 +3960,33 @@ def test_list_slice_bad_parameters():
39643960
pc.list_slice(arr, 0, 1, step=-1)
39653961

39663962

3963+
def test_list_slice_empty_lists():
3964+
# Test start == stop should return empty lists
3965+
arr = pa.array([[1, 2, 3], [4, 5, None], [6, None, None], None])
3966+
result = pc.list_slice(arr, 0, 0)
3967+
expected = pa.array([[], [], [], None], type=pa.list_(pa.int64()))
3968+
assert result.equals(expected)
3969+
3970+
# Test with different start positions
3971+
result = pc.list_slice(arr, 1, 1)
3972+
assert result.equals(expected)
3973+
3974+
result = pc.list_slice(arr, 2, 2)
3975+
assert result.equals(expected)
3976+
3977+
# Test with large_list
3978+
arr_large = pa.array([[1, 2, 3], [4, 5, None]], pa.large_list(pa.int64()))
3979+
result = pc.list_slice(arr_large, 0, 0)
3980+
expected_large = pa.array([[], []], pa.large_list(pa.int64()))
3981+
assert result.equals(expected_large)
3982+
3983+
# Test with fixed_size_list -> output is fixed_size_list[0]
3984+
arr_fixed = pa.array([[1, 2, 3], [4, 5, 6]], pa.list_(pa.int64(), 3))
3985+
result = pc.list_slice(arr_fixed, 0, 0)
3986+
expected_fixed = pa.array([[], []], pa.list_(pa.int64(), 0))
3987+
assert result.equals(expected_fixed)
3988+
3989+
39673990
def check_run_end_encode_decode(value_type, run_end_encode_opts=None):
39683991
values = [1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3]
39693992
arr = pa.array(values, type=value_type)

0 commit comments

Comments
 (0)