From cb8b3a77561348dc89a9bbf833b722298a5e8b83 Mon Sep 17 00:00:00 2001
From: Evert Lammerts
Date: Thu, 28 Aug 2025 16:05:46 +0200
Subject: [PATCH] Fwd port of https://github.com/duckdb/duckdb/pull/15789/files

---
Review note: the cache() docstring below is reworded from the upstream PR so that it
describes this implementation instead of Spark storage levels. Added line counts are
unchanged, so the hunk headers and diffstat still match.

 duckdb/experimental/spark/sql/dataframe.py | 30 ++++++++++++++++++++++
 tests/fast/spark/test_spark_dataframe.py   |  8 ++++++
 2 files changed, 38 insertions(+)

diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py
index 31b13ded..b8a4698b 100644
--- a/duckdb/experimental/spark/sql/dataframe.py
+++ b/duckdb/experimental/spark/sql/dataframe.py
@@ -1403,5 +1403,35 @@ def construct_row(values, names) -> Row:
         rows = [construct_row(x, columns) for x in result]
         return rows
 
+    def cache(self) -> "DataFrame":
+        """Execute the underlying relation and return it wrapped in a new :class:`DataFrame`.
+
+        This mirrors the PySpark ``DataFrame.cache`` API, but no Spark storage level is
+        involved. The method calls ``execute()`` on the wrapped relation and builds a
+        new :class:`DataFrame` from whatever that call returns, reusing this
+        DataFrame's session.
+
+        Notes
+        -----
+        Presumably the executed relation holds on to its result so that later actions do
+        not re-run the query; that depends on ``relation.execute()`` -- TODO confirm.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            A new DataFrame (never ``self``) built from the executed relation.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame([(1, 2, 3, 4)], ["one", "two", "three", "four"])
+        >>> cached = df.cache()
+        >>> cached is df
+        False
+        >>> cached.collect() == df.collect()
+        True
+        """
+        cached_relation = self.relation.execute()
+        return DataFrame(cached_relation, self.session)
+
 
 __all__ = ["DataFrame"]
diff --git a/tests/fast/spark/test_spark_dataframe.py b/tests/fast/spark/test_spark_dataframe.py
index 5b7492d7..d88b03eb 100644
--- a/tests/fast/spark/test_spark_dataframe.py
+++ b/tests/fast/spark/test_spark_dataframe.py
@@ -421,3 +421,11 @@ def test_drop(self, spark):
         assert df.drop("two", "three").columns == expected
         assert df.drop("two", col("three")).columns == expected
         assert df.drop("two", col("three"), col("missing")).columns == expected
+
+    def test_cache(self, spark):
+        data = [(1, 2, 3, 4)]
+        df = spark.createDataFrame(data, ["one", "two", "three", "four"])
+        cached = df.cache()
+        assert df is not cached  # cache() returns a new DataFrame object, not the receiver
+        assert cached.collect() == df.collect()  # same rows as the original DataFrame
+        assert cached.collect() == [Row(one=1, two=2, three=3, four=4)]