Commit acaef41

1 parent 4a947b9 commit acaef41

5 files changed, +229 −70 lines changed

duckdb/experimental/spark/sql/functions.py

Lines changed: 123 additions & 0 deletions
@@ -1851,6 +1851,30 @@ def isnotnull(col: "ColumnOrName") -> Column:
     return Column(_to_column_expr(col).isnotnull())
 
 
+def equal_null(col1: "ColumnOrName", col2: "ColumnOrName") -> Column:
+    """
+    Returns the same result as the EQUAL(=) operator for non-null operands,
+    but returns true if both are null, false if one of them is null.
+    .. versionadded:: 3.5.0
+    Parameters
+    ----------
+    col1 : :class:`~pyspark.sql.Column` or str
+    col2 : :class:`~pyspark.sql.Column` or str
+    Examples
+    --------
+    >>> df = spark.createDataFrame([(None, None,), (1, 9,)], ["a", "b"])
+    >>> df.select(equal_null(df.a, df.b).alias('r')).collect()
+    [Row(r=True), Row(r=False)]
+    """
+    if isinstance(col1, str):
+        col1 = col(col1)
+
+    if isinstance(col2, str):
+        col2 = col(col2)
+
+    return nvl((col1 == col2) | (col1.isNull() & col2.isNull()), lit(False))
+
+
 def flatten(col: "ColumnOrName") -> Column:
     """
     Collection function: creates a single array from an array of arrays.
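For reviewers who want to poke at the new null-safe comparison outside the doctest, here is a minimal sketch (not part of the diff). It assumes duckdb.experimental.spark.sql exposes a SparkSession builder alongside the functions module shown above; plain == yields NULL when either side is NULL, while equal_null always yields a boolean.

# Illustrative sketch only, not part of this commit.
# Assumes duckdb.experimental.spark.sql provides SparkSession.builder.
from duckdb.experimental.spark.sql import SparkSession
from duckdb.experimental.spark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(None, None,), (1, 9,), (1, 1,)], ["a", "b"])

df.select(
    (df.a == df.b).alias("eq"),                  # NULL, False, True
    sf.equal_null(df.a, df.b).alias("eq_null"),  # True, False, True
).show()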
@@ -2157,6 +2181,33 @@ def e() -> Column:
     return lit(2.718281828459045)
 
 
+def negative(col: "ColumnOrName") -> Column:
+    """
+    Returns the negative value.
+    .. versionadded:: 3.5.0
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        column to calculate negative value for.
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        negative value.
+    Examples
+    --------
+    >>> import pyspark.sql.functions as sf
+    >>> spark.range(3).select(sf.negative("id")).show()
+    +------------+
+    |negative(id)|
+    +------------+
+    |           0|
+    |          -1|
+    |          -2|
+    +------------+
+    """
+    return Column(_to_column_expr(col)) * -1
+
+
 def pi() -> Column:
     """Returns Pi.
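As a quick sanity check on the sign-flip semantics (again a sketch, not part of the commit, with the same assumed session setup as above):

# Illustrative sketch only, not part of this commit.
from duckdb.experimental.spark.sql import SparkSession
from duckdb.experimental.spark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(5,), (-3,), (0,)], ["v"])

# negative() flips the sign, so an already-negative input becomes positive.
df.select(df.v, sf.negative(df.v).alias("neg")).show()
# Expected pairs: (5, -5), (-3, 3), (0, 0)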
@@ -3774,6 +3825,53 @@ def datepart(field: "ColumnOrName", source: "ColumnOrName") -> Column:
     return date_part(field, source)
 
 
+def date_diff(end: "ColumnOrName", start: "ColumnOrName") -> Column:
+    """
+    Returns the number of days from `start` to `end`.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    end : :class:`~pyspark.sql.Column` or column name
+        to date column to work on.
+    start : :class:`~pyspark.sql.Column` or column name
+        from date column to work on.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        difference in days between two dates.
+
+    See Also
+    --------
+    :meth:`pyspark.sql.functions.dateadd`
+    :meth:`pyspark.sql.functions.date_add`
+    :meth:`pyspark.sql.functions.date_sub`
+    :meth:`pyspark.sql.functions.datediff`
+    :meth:`pyspark.sql.functions.timestamp_diff`
+
+    Examples
+    --------
+    >>> import pyspark.sql.functions as sf
+    >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2'])
+    >>> df.select('*', sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE'))).show()
+    +----------+----------+-----------------+
+    |        d1|        d2|date_diff(d1, d2)|
+    +----------+----------+-----------------+
+    |2015-04-08|2015-05-10|              -32|
+    +----------+----------+-----------------+
+
+    >>> df.select('*', sf.date_diff(sf.col('d2').cast('DATE'), sf.col('d1').cast('DATE'))).show()
+    +----------+----------+-----------------+
+    |        d1|        d2|date_diff(d2, d1)|
+    +----------+----------+-----------------+
+    |2015-04-08|2015-05-10|               32|
+    +----------+----------+-----------------+
+    """
+    return _invoke_function_over_columns("date_diff", lit("day"), end, start)
+
+
 def year(col: "ColumnOrName") -> Column:
     """
     Extract the year of a given date/timestamp as integer.
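The two doctest outputs above differ only in argument order; a compact way to see both signs at once (sketch only, not part of the commit, same assumed session setup):

# Illustrative sketch only, not part of this commit.
from duckdb.experimental.spark.sql import SparkSession
from duckdb.experimental.spark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([('2015-04-08', '2015-05-10')], ['d1', 'd2'])

# date_diff(end, start) counts days from `start` to `end`, so swapping
# the arguments flips the sign (-32 vs. 32 per the docstring above).
df.select(
    sf.date_diff(sf.col('d1').cast('DATE'), sf.col('d2').cast('DATE')).alias('d1_minus_d2'),
    sf.date_diff(sf.col('d2').cast('DATE'), sf.col('d1').cast('DATE')).alias('d2_minus_d1'),
).show()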
@@ -5685,6 +5783,31 @@ def to_timestamp_ntz(
     return _to_date_or_timestamp(timestamp, _types.TimestampNTZType(), format)
 
 
+def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> Column:
+    """
+    Parses the `col` with the `format` to a timestamp. The function always
+    returns null on an invalid input with/without ANSI SQL mode enabled. The result data type is
+    consistent with the value of configuration `spark.sql.timestampType`.
+    .. versionadded:: 3.5.0
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        column values to convert.
+    format: str, optional
+        format to use to convert timestamp values.
+    Examples
+    --------
+    >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
+    >>> df.select(try_to_timestamp(df.t).alias('dt')).collect()
+    [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
+    >>> df.select(try_to_timestamp(df.t, lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).collect()
+    [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
+    """
+    if format is None:
+        format = lit(['%Y-%m-%d', '%Y-%m-%d %H:%M:%S'])
+
+    return _invoke_function_over_columns("try_strptime", col, format)
+
 def substr(
     str: "ColumnOrName", pos: "ColumnOrName", len: Optional["ColumnOrName"] = None
 ) -> Column:
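Because the default path hands a list of strptime patterns to DuckDB's try_strptime, malformed input comes back as NULL instead of raising. A sketch (not part of the commit, same assumed session setup):

# Illustrative sketch only, not part of this commit.
from duckdb.experimental.spark.sql import SparkSession
from duckdb.experimental.spark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([('1997-02-28 10:30:00',), ('not a timestamp',)], ['t'])

# Valid strings parse to timestamps; invalid ones yield NULL rather than an error.
df.select(df.t, sf.try_to_timestamp(df.t).alias('ts')).show()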

tests/fast/spark/test_spark_functions_array.py

Lines changed: 32 additions & 32 deletions
@@ -2,7 +2,7 @@
 import platform
 
 _ = pytest.importorskip("duckdb.experimental.spark")
-from spark_namespace.sql import functions as F
+from spark_namespace.sql import functions as sf
 from spark_namespace.sql.types import Row
 from spark_namespace import USE_ACTUAL_SPARK
@@ -19,7 +19,7 @@ def test_array_distinct(self, spark):
             ([2, 4, 5], 3),
         ]
         df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
-        df = df.withColumn("distinct_values", F.array_distinct(F.col("firstColumn")))
+        df = df.withColumn("distinct_values", sf.array_distinct(sf.col("firstColumn")))
         res = df.select("distinct_values").collect()
         # Output order can vary across platforms which is why we sort it first
         assert len(res) == 2
@@ -31,7 +31,7 @@ def test_array_intersect(self, spark):
             (["b", "a", "c"], ["c", "d", "a", "f"]),
         ]
         df = spark.createDataFrame(data, ["c1", "c2"])
-        df = df.withColumn("intersect_values", F.array_intersect(F.col("c1"), F.col("c2")))
+        df = df.withColumn("intersect_values", sf.array_intersect(sf.col("c1"), sf.col("c2")))
         res = df.select("intersect_values").collect()
         # Output order can vary across platforms which is why we sort it first
         assert len(res) == 1
@@ -42,7 +42,7 @@ def test_array_union(self, spark):
             (["b", "a", "c"], ["c", "d", "a", "f"]),
         ]
         df = spark.createDataFrame(data, ["c1", "c2"])
-        df = df.withColumn("union_values", F.array_union(F.col("c1"), F.col("c2")))
+        df = df.withColumn("union_values", sf.array_union(sf.col("c1"), sf.col("c2")))
         res = df.select("union_values").collect()
         # Output order can vary across platforms which is why we sort it first
         assert len(res) == 1
@@ -54,7 +54,7 @@ def test_array_max(self, spark):
             ([4, 2, 5], 5),
         ]
         df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
-        df = df.withColumn("max_value", F.array_max(F.col("firstColumn")))
+        df = df.withColumn("max_value", sf.array_max(sf.col("firstColumn")))
         res = df.select("max_value").collect()
         assert res == [
             Row(max_value=3),
@@ -67,7 +67,7 @@ def test_array_min(self, spark):
             ([2, 4, 5], 5),
         ]
         df = spark.createDataFrame(data, ["firstColumn", "secondColumn"])
-        df = df.withColumn("min_value", F.array_min(F.col("firstColumn")))
+        df = df.withColumn("min_value", sf.array_min(sf.col("firstColumn")))
         res = df.select("min_value").collect()
         assert res == [
             Row(max_value=1),
@@ -77,58 +77,58 @@ def test_array_min(self, spark):
     def test_get(self, spark):
         df = spark.createDataFrame([(["a", "b", "c"], 1)], ['data', 'index'])
 
-        res = df.select(F.get(df.data, 1).alias("r")).collect()
+        res = df.select(sf.get(df.data, 1).alias("r")).collect()
         assert res == [Row(r="b")]
 
-        res = df.select(F.get(df.data, -1).alias("r")).collect()
+        res = df.select(sf.get(df.data, -1).alias("r")).collect()
         assert res == [Row(r=None)]
 
-        res = df.select(F.get(df.data, 3).alias("r")).collect()
+        res = df.select(sf.get(df.data, 3).alias("r")).collect()
         assert res == [Row(r=None)]
 
-        res = df.select(F.get(df.data, "index").alias("r")).collect()
+        res = df.select(sf.get(df.data, "index").alias("r")).collect()
         assert res == [Row(r='b')]
 
-        res = df.select(F.get(df.data, F.col("index") - 1).alias("r")).collect()
+        res = df.select(sf.get(df.data, sf.col("index") - 1).alias("r")).collect()
         assert res == [Row(r='a')]
 
     def test_flatten(self, spark):
         df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
 
-        res = df.select(F.flatten(df.data).alias("r")).collect()
+        res = df.select(sf.flatten(df.data).alias("r")).collect()
         assert res == [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]
 
     def test_array_compact(self, spark):
         df = spark.createDataFrame([([1, None, 2, 3],), ([4, 5, None, 4],)], ['data'])
 
-        res = df.select(F.array_compact(df.data).alias("v")).collect()
+        res = df.select(sf.array_compact(df.data).alias("v")).collect()
         assert [Row(v=[1, 2, 3]), Row(v=[4, 5, 4])]
 
     def test_array_remove(self, spark):
         df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data'])
 
-        res = df.select(F.array_remove(df.data, 1).alias("v")).collect()
+        res = df.select(sf.array_remove(df.data, 1).alias("v")).collect()
         assert res == [Row(v=[2, 3]), Row(v=[])]
 
     def test_array_agg(self, spark):
         df = spark.createDataFrame([[1, "A"], [1, "A"], [2, "A"]], ["c", "group"])
 
-        res = df.groupBy("group").agg(F.array_agg("c").alias("r")).collect()
+        res = df.groupBy("group").agg(sf.array_agg("c").alias("r")).collect()
         assert res[0] == Row(group="A", r=[1, 1, 2])
 
     def test_collect_list(self, spark):
         df = spark.createDataFrame([[1, "A"], [1, "A"], [2, "A"]], ["c", "group"])
 
-        res = df.groupBy("group").agg(F.collect_list("c").alias("r")).collect()
+        res = df.groupBy("group").agg(sf.collect_list("c").alias("r")).collect()
         assert res[0] == Row(group="A", r=[1, 1, 2])
 
     def test_array_append(self, spark):
         df = spark.createDataFrame([Row(c1=["b", "a", "c"], c2="c")], ["c1", "c2"])
 
-        res = df.select(F.array_append(df.c1, df.c2).alias("r")).collect()
+        res = df.select(sf.array_append(df.c1, df.c2).alias("r")).collect()
         assert res == [Row(r=['b', 'a', 'c', 'c'])]
 
-        res = df.select(F.array_append(df.c1, 'x')).collect()
+        res = df.select(sf.array_append(df.c1, 'x')).collect()
         assert res == [Row(r=['b', 'a', 'c', 'x'])]
 
     def test_array_insert(self, spark):
@@ -137,21 +137,21 @@ def test_array_insert(self, spark):
             ['data', 'pos', 'val'],
         )
 
-        res = df.select(F.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
+        res = df.select(sf.array_insert(df.data, df.pos.cast('integer'), df.val).alias('data')).collect()
         assert res == [
             Row(data=['a', 'd', 'b', 'c']),
             Row(data=['a', 'd', 'b', 'c', 'e']),
             Row(data=['c', 'b', 'd', 'a']),
         ]
 
-        res = df.select(F.array_insert(df.data, 5, 'hello').alias('data')).collect()
+        res = df.select(sf.array_insert(df.data, 5, 'hello').alias('data')).collect()
         assert res == [
             Row(data=['a', 'b', 'c', None, 'hello']),
             Row(data=['a', 'b', 'c', 'e', 'hello']),
             Row(data=['c', 'b', 'a', None, 'hello']),
         ]
 
-        res = df.select(F.array_insert(df.data, -5, 'hello').alias('data')).collect()
+        res = df.select(sf.array_insert(df.data, -5, 'hello').alias('data')).collect()
         assert res == [
             Row(data=['hello', None, 'a', 'b', 'c']),
             Row(data=['hello', 'a', 'b', 'c', 'e']),
@@ -160,67 +160,67 @@ def test_array_insert(self, spark):
 
     def test_slice(self, spark):
         df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
-        res = df.select(F.slice(df.x, 2, 2).alias("sliced")).collect()
+        res = df.select(sf.slice(df.x, 2, 2).alias("sliced")).collect()
         assert res == [Row(sliced=[2, 3]), Row(sliced=[5])]
 
     def test_sort_array(self, spark):
         df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data'])
 
-        res = df.select(F.sort_array(df.data).alias('r')).collect()
+        res = df.select(sf.sort_array(df.data).alias('r')).collect()
         assert res == [Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])]
 
-        res = df.select(F.sort_array(df.data, asc=False).alias('r')).collect()
+        res = df.select(sf.sort_array(df.data, asc=False).alias('r')).collect()
         assert res == [Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])]
 
     @pytest.mark.parametrize(("null_replacement", "expected_joined_2"), [(None, "a"), ("replaced", "a,replaced")])
     def test_array_join(self, spark, null_replacement, expected_joined_2):
         df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])
 
-        res = df.select(F.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect()
+        res = df.select(sf.array_join(df.data, ",", null_replacement=null_replacement).alias("joined")).collect()
         assert res == [Row(joined='a,b,c'), Row(joined=expected_joined_2)]
 
     def test_array_position(self, spark):
         df = spark.createDataFrame([(["c", "b", "a"],), ([],)], ['data'])
 
-        res = df.select(F.array_position(df.data, "a").alias("pos")).collect()
+        res = df.select(sf.array_position(df.data, "a").alias("pos")).collect()
         assert res == [Row(pos=3), Row(pos=0)]
 
     def test_array_preprend(self, spark):
         df = spark.createDataFrame([([2, 3, 4],), ([],)], ['data'])
 
-        res = df.select(F.array_prepend(df.data, 1).alias("pre")).collect()
+        res = df.select(sf.array_prepend(df.data, 1).alias("pre")).collect()
         assert res == [Row(pre=[1, 2, 3, 4]), Row(pre=[1])]
 
     def test_array_repeat(self, spark):
         df = spark.createDataFrame([('ab',)], ['data'])
 
-        res = df.select(F.array_repeat(df.data, 3).alias('r')).collect()
+        res = df.select(sf.array_repeat(df.data, 3).alias('r')).collect()
         assert res == [Row(r=['ab', 'ab', 'ab'])]
 
     def test_array_size(self, spark):
         df = spark.createDataFrame([([2, 1, 3],), (None,)], ['data'])
 
-        res = df.select(F.array_size(df.data).alias('r')).collect()
+        res = df.select(sf.array_size(df.data).alias('r')).collect()
         assert res == [Row(r=3), Row(r=None)]
 
     def test_array_sort(self, spark):
         df = spark.createDataFrame([([2, 1, None, 3],), ([1],), ([],)], ['data'])
 
-        res = df.select(F.array_sort(df.data).alias('r')).collect()
+        res = df.select(sf.array_sort(df.data).alias('r')).collect()
         assert res == [Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
 
     def test_arrays_overlap(self, spark):
         df = spark.createDataFrame(
             [(["a", "b"], ["b", "c"]), (["a"], ["b", "c"]), ([None, "c"], ["a"]), ([None, "c"], [None])], ['x', 'y']
         )
 
-        res = df.select(F.arrays_overlap(df.x, df.y).alias("overlap")).collect()
+        res = df.select(sf.arrays_overlap(df.x, df.y).alias("overlap")).collect()
         assert res == [Row(overlap=True), Row(overlap=False), Row(overlap=None), Row(overlap=None)]
 
     def test_arrays_zip(self, spark):
         df = spark.createDataFrame([([1, 2, 3], [2, 4, 6], [3, 6])], ['vals1', 'vals2', 'vals3'])
 
-        res = df.select(F.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect()
+        res = df.select(sf.arrays_zip(df.vals1, df.vals2, df.vals3).alias('zipped')).collect()
         # FIXME: The structure of the results should be the same
         if USE_ACTUAL_SPARK:
             assert res == [
