
Commit 984e16b

[SPARK-53657][PYTHON][TESTS] Enable doctests for GroupedData.agg
### What changes were proposed in this pull request?

Enable doctests for `GroupedData.agg`; some doctests were skipped due to a dependency on pandas/pyarrow installations.

### Why are the changes needed?

To improve test coverage.

### Does this PR introduce _any_ user-facing change?

Doc-only changes.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52404 from zhengruifeng/enable_group_agg_doctest.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
Parent: f48de10 · Commit: 984e16b
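The change works because `doctest` collects examples from an object's `__doc__`: once the docstring is deleted, the collector simply finds nothing to run for that method. A minimal self-contained sketch of the pattern, using an illustrative `Greeter` class rather than Spark code:

```python
import doctest


class Greeter:
    def greet(self) -> str:
        """
        >>> Greeter().greet()
        'hello'
        """
        return "hello"


def _test() -> None:
    # Probe for the optional dependency, mirroring have_pandas in the commit.
    try:
        import pandas  # noqa: F401
        have_pandas = True
    except ImportError:
        have_pandas = False

    if not have_pandas:
        # With __doc__ gone, doctest sees no examples on greet and skips it.
        del Greeter.greet.__doc__

    failures, _ = doctest.testmod()
    if failures:
        raise SystemExit(-1)


if __name__ == "__main__":
    _test()
```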

2 files changed (+13 additions, -4 deletions)

2 files changed

+13
-4
lines changed

python/pyspark/sql/connect/group.py

Lines changed: 4 additions & 0 deletions
@@ -583,9 +583,13 @@ def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
     import pyspark.sql.connect.group
+    from pyspark.testing.utils import have_pandas, have_pyarrow
 
     globs = pyspark.sql.connect.group.__dict__.copy()
 
+    if not have_pandas or not have_pyarrow:
+        del pyspark.sql.connect.group.GroupedData.agg.__doc__
+
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.group tests")
         .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
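`have_pandas` and `have_pyarrow` imported from `pyspark.testing.utils` are boolean availability flags. A hedged sketch of how such flags are conventionally derived via try-imports (the exact pyspark implementation may differ in detail):

```python
# Availability probes for optional dependencies; each flag is a plain bool.
try:
    import pandas  # noqa: F401
    have_pandas = True
except ImportError:
    have_pandas = False

try:
    import pyarrow  # noqa: F401
    have_pyarrow = True
except ImportError:
    have_pyarrow = False
```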

python/pyspark/sql/group.py

Lines changed: 9 additions & 4 deletions
@@ -126,9 +126,8 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
 
         Examples
         --------
-        >>> import pandas as pd  # doctest: +SKIP
+        >>> import pandas as pd
         >>> from pyspark.sql import functions as sf
-        >>> from pyspark.sql.functions import pandas_udf
         >>> df = spark.createDataFrame(
         ...     [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"])
         >>> df.show()
@@ -166,11 +165,12 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
 
         Same as above but uses pandas UDF.
 
-        >>> @pandas_udf('int')  # doctest: +SKIP
+        >>> from pyspark.sql.functions import pandas_udf
+        >>> @pandas_udf('int')
         ... def min_udf(v: pd.Series) -> int:
         ...     return v.min()
         ...
-        >>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show()  # doctest: +SKIP
+        >>> df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show()
         +-----+------------+
         | name|min_udf(age)|
         +-----+------------+
@@ -533,8 +533,13 @@ def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession
    import pyspark.sql.group
+    from pyspark.testing.utils import have_pandas, have_pyarrow
 
     globs = pyspark.sql.group.__dict__.copy()
+
+    if not have_pandas or not have_pyarrow:
+        del pyspark.sql.group.GroupedData.agg.__doc__
+
     spark = SparkSession.builder.master("local[4]").appName("sql.group tests").getOrCreate()
     globs["spark"] = spark
 
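Lifted out of the doctest, the example this commit un-skips runs as a standalone script when pandas and pyarrow are installed; the master and app name below are illustrative choices, and the expected rows follow from taking each group's minimum age (Alice: min(2, 3) = 2, Bob: min(5, 10) = 5):

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

# Illustrative session setup; any working SparkSession named `spark` will do.
spark = SparkSession.builder.master("local[4]").appName("agg_doctest_demo").getOrCreate()

df = spark.createDataFrame(
    [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"])


@pandas_udf('int')
def min_udf(v: pd.Series) -> int:
    # Series-to-scalar pandas UDF: returns one value per group.
    return v.min()


# Shows min(age) per name: Alice -> 2, Bob -> 5.
df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show()

spark.stop()
```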