Skip to content

Commit ba53bd1

Browse files
committed
feat: add utf8_literal function to create UTF8 literal expressions
1 parent 193d21c commit ba53bd1

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

python/datafusion/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def literal(value):
107107
return Expr.literal(value)
108108

109109

110+
def utf8_literal(value):
111+
"""Create a UTF8 literal expression."""
112+
return Expr.utf8_literal(value)
113+
114+
110115
def lit(value):
111116
"""Create a literal expression."""
112117
return Expr.literal(value)

python/datafusion/expr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,12 @@ def literal(value: Any) -> Expr:
380380
value = pa.scalar(value)
381381
return Expr(expr_internal.Expr.literal(value))
382382

383+
@staticmethod
384+
def utf8_literal(value: str) -> Expr:
385+
"""Creates a new expression representing a UTF8 literal value."""
386+
value = pa.scalar(value, type=pa.string())
387+
return Expr(expr_internal.Expr.literal(value))
388+
383389
@staticmethod
384390
def column(value: str) -> Expr:
385391
"""Creates a new expression representing a column."""

python/tests/test_functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323

2424
from datafusion import SessionContext, column
2525
from datafusion import functions as f
26-
from datafusion import literal
26+
from datafusion import literal, utf8_literal
27+
from datafusion import Expr
28+
from datafusion.expr import expr_internal
2729

2830
np.seterr(invalid="ignore")
2931

@@ -907,8 +909,8 @@ def test_temporal_functions(df):
907909

908910
def test_arrow_cast(df):
909911
df = df.select(
910-
f.arrow_cast(column("a"), literal("Float64")).alias("a_as_float"),
911-
f.arrow_cast(column("a"), literal("Int32")).alias("a_as_int"),
912+
f.arrow_cast(column("a"), utf8_literal("Float64")).alias("a_as_float"),
913+
f.arrow_cast(column("a"), utf8_literal("Int32")).alias("a_as_int"),
912914
)
913915
result = df.collect()
914916
assert len(result) == 1

0 commit comments

Comments
 (0)