Skip to content

Commit ac22014

Browse files
Support cudf-polars str.head and str.tail (rapidsai#19115)
Closes rapidsai#19031
Closes rapidsai#18995

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: rapidsai#19115
1 parent b90691f commit ac22014

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

python/cudf_polars/cudf_polars/dsl/expressions/string.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def _validate_input(self) -> None:
113113
StringFunction.Name.ConcatVertical,
114114
StringFunction.Name.Contains,
115115
StringFunction.Name.EndsWith,
116+
StringFunction.Name.Head,
116117
StringFunction.Name.Lowercase,
117118
StringFunction.Name.Replace,
118119
StringFunction.Name.ReplaceMany,
@@ -123,6 +124,7 @@ def _validate_input(self) -> None:
123124
StringFunction.Name.StripCharsStart,
124125
StringFunction.Name.StripCharsEnd,
125126
StringFunction.Name.Uppercase,
127+
StringFunction.Name.Tail,
126128
):
127129
raise NotImplementedError(f"String function {self.name!r}")
128130
if self.name is StringFunction.Name.Contains:
@@ -283,6 +285,60 @@ def do_evaluate(
283285
side = plc.strings.SideType.BOTH
284286
return Column(plc.strings.strip.strip(column.obj, side, chars.obj_scalar))
285287

288+
elif self.name is StringFunction.Name.Tail:
289+
column = self.children[0].evaluate(df, context=context)
290+
291+
assert isinstance(self.children[1], Literal)
292+
if self.children[1].value is None:
293+
return Column(
294+
plc.Column.from_scalar(
295+
plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING)),
296+
column.size,
297+
)
298+
)
299+
elif self.children[1].value == 0:
300+
result = plc.Column.from_scalar(
301+
plc.Scalar.from_py("", plc.DataType(plc.TypeId.STRING)),
302+
column.size,
303+
)
304+
if column.obj.null_mask():
305+
result = result.with_mask(
306+
column.obj.null_mask(), column.obj.null_count()
307+
)
308+
return Column(result)
309+
310+
else:
311+
start = -(self.children[1].value)
312+
end = 2**31 - 1
313+
return Column(
314+
plc.strings.slice.slice_strings(
315+
column.obj,
316+
plc.Scalar.from_py(start, plc.DataType(plc.TypeId.INT32)),
317+
plc.Scalar.from_py(end, plc.DataType(plc.TypeId.INT32)),
318+
None,
319+
)
320+
)
321+
elif self.name is StringFunction.Name.Head:
322+
column = self.children[0].evaluate(df, context=context)
323+
324+
assert isinstance(self.children[1], Literal)
325+
326+
end = self.children[1].value
327+
if end is None:
328+
return Column(
329+
plc.Column.from_scalar(
330+
plc.Scalar.from_py(None, plc.DataType(plc.TypeId.STRING)),
331+
column.size,
332+
)
333+
)
334+
return Column(
335+
plc.strings.slice.slice_strings(
336+
column.obj,
337+
plc.Scalar.from_py(0, plc.DataType(plc.TypeId.INT32)),
338+
plc.Scalar.from_py(end, plc.DataType(plc.TypeId.INT32)),
339+
)
340+
)
341+
286342
columns = [child.evaluate(df, context=context) for child in self.children]
287343
if self.name is StringFunction.Name.Lowercase:
288344
(column,) = columns
@@ -352,6 +408,7 @@ def do_evaluate(
352408
return Column(
353409
plc.strings.replace.replace_multiple(column.obj, target.obj, repl.obj)
354410
)
411+
355412
raise NotImplementedError(
356413
f"StringFunction {self.name}"
357414
) # pragma: no cover; handled by init raising

python/cudf_polars/tests/expressions/test_stringfunction.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,3 +470,15 @@ def test_string_to_numeric_invalid(numeric_type):
470470
def test_string_join(ldf, ignore_nulls, delimiter):
471471
q = ldf.select(pl.col("a").str.join(delimiter, ignore_nulls=ignore_nulls))
472472
assert_gpu_result_equal(q)
473+
474+
475+
@pytest.mark.parametrize("tail", [1, 2, 999, -1, 0, None])
476+
def test_string_tail(ldf, tail):
477+
q = ldf.select(pl.col("a").str.tail(tail))
478+
assert_gpu_result_equal(q)
479+
480+
481+
@pytest.mark.parametrize("head", [1, 2, 999, -1, 0, None])
482+
def test_string_head(ldf, head):
483+
q = ldf.select(pl.col("a").str.head(head))
484+
assert_gpu_result_equal(q)

0 commit comments

Comments
 (0)