
Commit a0c2fa6

zero323 authored and dongjoon-hyun committed
[SPARK-28439][PYTHON][SQL] Add support for count: Column in array_repeat
## What changes were proposed in this pull request?

This adds a simple check for the `count` argument:

- If it is a `Column`, we apply `_to_java_column` before invoking the JVM counterpart.
- Otherwise we proceed as before.

## How was this patch tested?

Manual testing.

Closes apache#25193 from zero323/SPARK-28278.

Authored-by: zero323 <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 2cf0491 commit a0c2fa6

File tree: 2 files changed (+14, -1 lines)


python/pyspark/sql/functions.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -2698,7 +2698,10 @@ def array_repeat(col, count):
     [Row(r=[u'ab', u'ab', u'ab'])]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count))
+    return Column(sc._jvm.functions.array_repeat(
+        _to_java_column(col),
+        _to_java_column(count) if isinstance(count, Column) else count
+    ))


 @since(2.4)
```
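The core of the change is the `isinstance`-based dispatch on `count`: a `Column` gets unwrapped into its JVM handle, while a plain Python value passes through unchanged. A minimal, Spark-free sketch of that pattern follows; the `Column` stand-in and `to_java` helper here are hypothetical stand-ins for `pyspark.sql.Column` and `_to_java_column`, for illustration only.

```python
class Column:
    """Stand-in for pyspark.sql.Column, for illustration only."""
    def __init__(self, jc):
        self._jc = jc


def to_java(col):
    # Hypothetical stand-in for pyspark's _to_java_column:
    # unwrap a Column into its underlying JVM handle.
    return col._jc


def normalize_count(count):
    # The pattern added by this commit: convert Column arguments,
    # pass plain Python values (e.g. int) through unchanged.
    return to_java(count) if isinstance(count, Column) else count


print(normalize_count(3))             # → 3 (plain int passes through)
print(normalize_count(Column("jc")))  # → jc (Column is unwrapped)
```

This keeps the public signature of `array_repeat` backward compatible: existing callers passing an `int` are untouched, while `Column` counts now work as well.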

python/pyspark/sql/tests/test_functions.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -294,6 +294,16 @@ def test_input_file_name_reset_for_rdd(self):
         for result in results:
             self.assertEqual(result[0], '')

+    def test_array_repeat(self):
+        from pyspark.sql.functions import array_repeat, lit
+
+        df = self.spark.range(1)
+
+        self.assertEquals(
+            df.select(array_repeat("id", 3)).toDF("val").collect(),
+            df.select(array_repeat("id", lit(3))).toDF("val").collect(),
+        )
+

 if __name__ == "__main__":
     import unittest
```
