diff --git a/plugins/flytekit-spark/flytekitplugins/spark/schema.py b/plugins/flytekit-spark/flytekitplugins/spark/schema.py
index fab60f485a..4c423e6894 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/schema.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/schema.py
@@ -84,6 +84,84 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
         return r.all()
 
 
+classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
+ClassicDataFrame = classic_ps_dataframe.DataFrame
+
+
+class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]):
+    """
+    Implements how Classic SparkDataFrame should be read using the ``open`` method of FlyteSchema
+    """
+
+    def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
+        super().__init__(from_path, cols, fmt)
+
+    def iter(self, **kwargs) -> typing.Generator[T, None, None]:
+        raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks")
+
+    def all(self, **kwargs) -> ClassicDataFrame:
+        if self._fmt == SchemaFormat.PARQUET:
+            ctx = FlyteContext.current_context().user_space_params
+            return ctx.spark_session.read.parquet(self.from_path)
+        raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
+
+
+class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]):
+    """
+    Implements how Classic SparkDataFrame should be written using the ``open`` method of FlyteSchema
+    """
+
+    def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
+        super().__init__(to_path, cols, fmt)
+
+    def write(self, *dfs: ClassicDataFrame, **kwargs):
+        if dfs is None or len(dfs) == 0:
+            return
+        if len(dfs) > 1:
+            raise AssertionError("Only a single Classic Spark.DataFrame can be written per variable currently")
+        if self._fmt == SchemaFormat.PARQUET:
+            dfs[0].write.mode("overwrite").parquet(self.to_path)
+            return
+        raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
+
+
+class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]):
+    """
+    Transforms Classic Spark DataFrames to and from a Schema (typed/untyped)
+    """
+
+    def __init__(self):
+        super().__init__("classic-spark-df-transformer", t=ClassicDataFrame)
+
+    @staticmethod
+    def _get_schema_type() -> SchemaType:
+        return SchemaType(columns=[])
+
+    def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType:
+        return LiteralType(schema=self._get_schema_type())
+
+    def to_literal(
+        self,
+        ctx: FlyteContext,
+        python_val: ClassicDataFrame,
+        python_type: Type[ClassicDataFrame],
+        expected: LiteralType,
+    ) -> Literal:
+        remote_path = ctx.file_access.join(
+            ctx.file_access.raw_output_prefix,
+            ctx.file_access.get_random_string(),
+        )
+        w = ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET)
+        w.write(python_val)
+        return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))
+
+    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T:
+        if not (lv and lv.scalar and lv.scalar.schema):
+            return ClassicDataFrame()
+        r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
+        return r.all()
+
+
 # %%
 # Registers a handle for Spark DataFrame + Flyte Schema type transition
 # This allows open(pyspark.DataFrame) to be an acceptable type
@@ -97,6 +175,15 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
     )
 )
+SchemaEngine.register_handler(
+    SchemaHandler(
+        "pyspark.sql.classic.DataFrame-Schema",
+        ClassicDataFrame,
+        ClassicSparkDataFrameSchemaReader,
+        ClassicSparkDataFrameSchemaWriter,
+        handles_remote_io=True,
+    )
+)
 
 # %%
 # This makes pyspark.DataFrame as a supported output/input type with flytekit.
 TypeEngine.register(SparkDataFrameTransformer())
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
index 2a0faa1b5d..a849b711dc 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
@@ -72,3 +72,52 @@ def decode(
 StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler())
 StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler())
 StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer())
+
+classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
+ClassicDataFrame = classic_ps_dataframe.DataFrame
+
+
+class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder):
+    def __init__(self):
+        super().__init__(ClassicDataFrame, None, PARQUET)
+
+    def encode(
+        self,
+        ctx: FlyteContext,
+        structured_dataset: StructuredDataset,
+        structured_dataset_type: StructuredDatasetType,
+    ) -> literals.StructuredDataset:
+        path = typing.cast(str, structured_dataset.uri)
+        if not path:
+            path = ctx.file_access.join(
+                ctx.file_access.raw_output_prefix,
+                ctx.file_access.get_random_string(),
+            )
+        df = typing.cast(ClassicDataFrame, structured_dataset.dataframe)
+        ss = pyspark.sql.SparkSession.builder.getOrCreate()
+        ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
+        df.write.mode("overwrite").parquet(path=path)
+        return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
+
+
+class ParquetToClassicSparkDecodingHandler(StructuredDatasetDecoder):
+    def __init__(self):
+        super().__init__(ClassicDataFrame, None, PARQUET)
+
+    def decode(
+        self,
+        ctx: FlyteContext,
+        flyte_value: literals.StructuredDataset,
+        current_task_metadata: StructuredDatasetMetadata,
+    ) -> ClassicDataFrame:
+        user_ctx = FlyteContext.current_context().user_space_params
+        if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
+            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
+            return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
+        return user_ctx.spark_session.read.parquet(flyte_value.uri)
+
+
+# Register the handlers
+StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
+StructuredDatasetTransformerEngine.register(ParquetToClassicSparkDecodingHandler())
+StructuredDatasetTransformerEngine.register_renderer(ClassicDataFrame, SparkDataFrameRenderer())
diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py
index 4dffeb8b6e..dd620947b9 100644
--- a/plugins/flytekit-spark/setup.py
+++ b/plugins/flytekit-spark/setup.py
@@ -4,8 +4,7 @@
 
 microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
 
-# TODO: Add support spark 4.0.0, https://github.com/flyteorg/flyte/issues/6478
-plugin_requires = ["flytekit>=1.15.1", "pyspark>=3.4.0,<4.0.0", "aiohttp", "flyteidl>=1.11.0b1", "pandas"]
+plugin_requires = ["flytekit>=1.15.1", "pyspark>=3.4.0", "aiohttp", "flyteidl>=1.11.0b1", "pandas"]
 
 __version__ = "0.0.0+develop"
 
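A minimal usage sketch (not part of this diff) of what the new handlers enable: under pyspark>=4.0.0 without Spark Connect, `spark.createDataFrame()` returns a `pyspark.sql.classic.dataframe.DataFrame` instance, which is why the classic-class encoder/decoder and schema handlers above are registered. The task and workflow names, Spark config, and sample data below are hypothetical.

```python
import flytekit
import pyspark.sql
from flytekit import task, workflow
from flytekitplugins.spark import Spark


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def make_df(n: int) -> pyspark.sql.DataFrame:
    # Under Spark 4 (classic mode), the session produces
    # pyspark.sql.classic.dataframe.DataFrame instances, the type the
    # handlers added in this diff are registered for.
    sess = flytekit.current_context().spark_session
    return sess.createDataFrame([(i, i * i) for i in range(n)], ["value", "square"])


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def count_rows(df: pyspark.sql.DataFrame) -> int:
    # The dataframe written to Parquet by the upstream task is read back here.
    return df.count()


@workflow
def wf(n: int = 10) -> int:
    return count_rows(df=make_df(n=n))
```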