Skip to content

Commit fcb5f96

Browse files
committed
Making a variety of adjustments in wrappers and unit tests to account for the switch from string to string_view as default
1 parent 73cfddf commit fcb5f96

File tree

4 files changed

+63
-27
lines changed

4 files changed

+63
-27
lines changed

python/datafusion/expr.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,7 @@ def fill_null(self, value: Any | Expr | None = None) -> Expr:
421421
_to_pyarrow_types = {
422422
float: pa.float64(),
423423
int: pa.int64(),
424-
str: pa.string_view(),
424+
str: pa.string(),
425425
bool: pa.bool_(),
426426
}
427427

python/datafusion/functions.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -295,7 +295,7 @@ def decode(input: Expr, encoding: Expr) -> Expr:
295295

296296
def array_to_string(expr: Expr, delimiter: Expr) -> Expr:
297297
"""Converts each element to its text representation."""
298-
return Expr(f.array_to_string(expr.expr, delimiter.expr))
298+
return Expr(f.array_to_string(expr.expr, delimiter.expr.cast(pa.string())))
299299

300300

301301
def array_join(expr: Expr, delimiter: Expr) -> Expr:
@@ -924,7 +924,7 @@ def to_timestamp(arg: Expr, *formatters: Expr) -> Expr:
924924
return f.to_timestamp(arg.expr)
925925

926926
formatters = [f.expr for f in formatters]
927-
return Expr(f.to_timestamp(arg.expr, *formatters))
927+
return Expr(f.to_timestamp(arg.expr.cast(pa.string()), *formatters))
928928

929929

930930
def to_timestamp_millis(arg: Expr, *formatters: Expr) -> Expr:
@@ -1065,7 +1065,10 @@ def struct(*args: Expr) -> Expr:
10651065

10661066
def named_struct(name_pairs: list[tuple[str, Expr]]) -> Expr:
10671067
"""Returns a struct with the given names and arguments pairs."""
1068-
name_pair_exprs = [[Expr.literal(pair[0]), pair[1]] for pair in name_pairs]
1068+
name_pair_exprs = [
1069+
[Expr.literal(pa.scalar(pair[0], type=pa.string())), pair[1]]
1070+
for pair in name_pairs
1071+
]
10691072

10701073
# flatten
10711074
name_pairs = [x.expr for xs in name_pair_exprs for x in xs]
@@ -1422,7 +1425,9 @@ def array_sort(array: Expr, descending: bool = False, null_first: bool = False)
14221425
nulls_first = "NULLS FIRST" if null_first else "NULLS LAST"
14231426
return Expr(
14241427
f.array_sort(
1425-
array.expr, Expr.literal(desc).expr, Expr.literal(nulls_first).expr
1428+
array.expr,
1429+
Expr.literal(pa.scalar(desc, type=pa.string())).expr,
1430+
Expr.literal(pa.scalar(nulls_first, type=pa.string())).expr,
14261431
)
14271432
)
14281433

python/tests/test_expr.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -130,7 +130,10 @@ def test_relational_expr(test_ctx):
130130
ctx = SessionContext()
131131

132132
batch = pa.RecordBatch.from_arrays(
133-
[pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
133+
[
134+
pa.array([1, 2, 3]),
135+
pa.array(["alpha", "beta", "gamma"], type=pa.string_view()),
136+
],
134137
names=["a", "b"],
135138
)
136139
df = ctx.create_dataframe([[batch]], name="batch_array")
@@ -145,7 +148,8 @@ def test_relational_expr(test_ctx):
145148
assert df.filter(col("b") == "beta").count() == 1
146149
assert df.filter(col("b") != "beta").count() == 2
147150

148-
assert df.filter(col("a") == "beta").count() == 0
151+
with pytest.raises(Exception):
152+
df.filter(col("a") == "beta").count()
149153

150154

151155
def test_expr_to_variant():

python/tests/test_functions.py

Lines changed: 47 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -34,9 +34,9 @@ def df():
3434
# create a RecordBatch and a new DataFrame from it
3535
batch = pa.RecordBatch.from_arrays(
3636
[
37-
pa.array(["Hello", "World", "!"]),
37+
pa.array(["Hello", "World", "!"], type=pa.string_view()),
3838
pa.array([4, 5, 6]),
39-
pa.array(["hello ", " world ", " !"]),
39+
pa.array(["hello ", " world ", " !"], type=pa.string_view()),
4040
pa.array(
4141
[
4242
datetime(2022, 12, 31),
@@ -88,16 +88,18 @@ def test_literal(df):
8888
assert len(result) == 1
8989
result = result[0]
9090
assert result.column(0) == pa.array([1] * 3)
91-
assert result.column(1) == pa.array(["1"] * 3)
92-
assert result.column(2) == pa.array(["OK"] * 3)
91+
assert result.column(1) == pa.array(["1"] * 3, type=pa.string_view())
92+
assert result.column(2) == pa.array(["OK"] * 3, type=pa.string_view())
9393
assert result.column(3) == pa.array([3.14] * 3)
9494
assert result.column(4) == pa.array([True] * 3)
9595
assert result.column(5) == pa.array([b"hello world"] * 3)
9696

9797

9898
def test_lit_arith(df):
9999
"""Test literals with arithmetic operations"""
100-
df = df.select(literal(1) + column("b"), f.concat(column("a"), literal("!")))
100+
df = df.select(
101+
literal(1) + column("b"), f.concat(column("a").cast(pa.string()), literal("!"))
102+
)
101103
result = df.collect()
102104
assert len(result) == 1
103105
result = result[0]
@@ -578,21 +580,33 @@ def test_array_function_obj_tests(stmt, py_expr):
578580
f.ascii(column("a")),
579581
pa.array([72, 87, 33], type=pa.int32()),
580582
), # H = 72; W = 87; ! = 33
581-
(f.bit_length(column("a")), pa.array([40, 40, 8], type=pa.int32())),
582-
(f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
583+
(
584+
f.bit_length(column("a").cast(pa.string())),
585+
pa.array([40, 40, 8], type=pa.int32()),
586+
),
587+
(
588+
f.btrim(literal(" World ")),
589+
pa.array(["World", "World", "World"], type=pa.string_view()),
590+
),
583591
(f.character_length(column("a")), pa.array([5, 5, 1], type=pa.int32())),
584592
(f.chr(literal(68)), pa.array(["D", "D", "D"])),
585593
(
586594
f.concat_ws("-", column("a"), literal("test")),
587595
pa.array(["Hello-test", "World-test", "!-test"]),
588596
),
589-
(f.concat(column("a"), literal("?")), pa.array(["Hello?", "World?", "!?"])),
597+
(
598+
f.concat(column("a").cast(pa.string()), literal("?")),
599+
pa.array(["Hello?", "World?", "!?"]),
600+
),
590601
(f.initcap(column("c")), pa.array(["Hello ", " World ", " !"])),
591602
(f.left(column("a"), literal(3)), pa.array(["Hel", "Wor", "!"])),
592603
(f.length(column("c")), pa.array([6, 7, 2], type=pa.int32())),
593604
(f.lower(column("a")), pa.array(["hello", "world", "!"])),
594605
(f.lpad(column("a"), literal(7)), pa.array([" Hello", " World", " !"])),
595-
(f.ltrim(column("c")), pa.array(["hello ", "world ", "!"])),
606+
(
607+
f.ltrim(column("c")),
608+
pa.array(["hello ", "world ", "!"], type=pa.string_view()),
609+
),
596610
(
597611
f.md5(column("a")),
598612
pa.array(
@@ -618,19 +632,25 @@ def test_array_function_obj_tests(stmt, py_expr):
618632
f.rpad(column("a"), literal(8)),
619633
pa.array(["Hello ", "World ", "! "]),
620634
),
621-
(f.rtrim(column("c")), pa.array(["hello", " world", " !"])),
635+
(
636+
f.rtrim(column("c")),
637+
pa.array(["hello", " world", " !"], type=pa.string_view()),
638+
),
622639
(
623640
f.split_part(column("a"), literal("l"), literal(1)),
624641
pa.array(["He", "Wor", "!"]),
625642
),
626643
(f.starts_with(column("a"), literal("Wor")), pa.array([False, True, False])),
627644
(f.strpos(column("a"), literal("o")), pa.array([5, 2, 0], type=pa.int32())),
628-
(f.substr(column("a"), literal(3)), pa.array(["llo", "rld", ""])),
645+
(
646+
f.substr(column("a"), literal(3)),
647+
pa.array(["llo", "rld", ""], type=pa.string_view()),
648+
),
629649
(
630650
f.translate(column("a"), literal("or"), literal("ld")),
631651
pa.array(["Helll", "Wldld", "!"]),
632652
),
633-
(f.trim(column("c")), pa.array(["hello", "world", "!"])),
653+
(f.trim(column("c")), pa.array(["hello", "world", "!"], type=pa.string_view())),
634654
(f.upper(column("c")), pa.array(["HELLO ", " WORLD ", " !"])),
635655
(f.ends_with(column("a"), literal("llo")), pa.array([True, False, False])),
636656
(
@@ -772,9 +792,9 @@ def test_temporal_functions(df):
772792
f.date_trunc(literal("month"), column("d")),
773793
f.datetrunc(literal("day"), column("d")),
774794
f.date_bin(
775-
literal("15 minutes"),
795+
literal("15 minutes").cast(pa.string()),
776796
column("d"),
777-
literal("2001-01-01 00:02:30"),
797+
literal("2001-01-01 00:02:30").cast(pa.string()),
778798
),
779799
f.from_unixtime(literal(1673383974)),
780800
f.to_timestamp(literal("2023-09-07 05:06:14.523952")),
@@ -836,8 +856,8 @@ def test_case(df):
836856
result = df.collect()
837857
result = result[0]
838858
assert result.column(0) == pa.array([10, 8, 8])
839-
assert result.column(1) == pa.array(["Hola", "Mundo", "!!"])
840-
assert result.column(2) == pa.array(["Hola", "Mundo", None])
859+
assert result.column(1) == pa.array(["Hola", "Mundo", "!!"], type=pa.string_view())
860+
assert result.column(2) == pa.array(["Hola", "Mundo", None], type=pa.string_view())
841861

842862

843863
def test_when_with_no_base(df):
@@ -855,8 +875,10 @@ def test_when_with_no_base(df):
855875
result = df.collect()
856876
result = result[0]
857877
assert result.column(0) == pa.array([4, 5, 6])
858-
assert result.column(1) == pa.array(["too small", "just right", "too big"])
859-
assert result.column(2) == pa.array(["Hello", None, None])
878+
assert result.column(1) == pa.array(
879+
["too small", "just right", "too big"], type=pa.string_view()
880+
)
881+
assert result.column(2) == pa.array(["Hello", None, None], type=pa.string_view())
860882

861883

862884
def test_regr_funcs_sql(df):
@@ -999,8 +1021,13 @@ def test_regr_funcs_df(func, expected):
9991021

10001022
def test_binary_string_functions(df):
10011023
df = df.select(
1002-
f.encode(column("a"), literal("base64")),
1003-
f.decode(f.encode(column("a"), literal("base64")), literal("base64")),
1024+
f.encode(column("a").cast(pa.string()), literal("base64").cast(pa.string())),
1025+
f.decode(
1026+
f.encode(
1027+
column("a").cast(pa.string()), literal("base64").cast(pa.string())
1028+
),
1029+
literal("base64").cast(pa.string()),
1030+
),
10041031
)
10051032
result = df.collect()
10061033
assert len(result) == 1

0 commit comments

Comments
 (0)