Commit 68f2c60

Add support for pyspark.sql.classic.dataframe.DataFrame transformer (#3272)
Signed-off-by: Nelson Chen <asd3431090@gmail.com>
1 parent eb5a67f commit 68f2c60

File tree

3 files changed (+137 -2 lines):
  plugins/flytekit-spark/flytekitplugins/spark/schema.py
  plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
  plugins/flytekit-spark/setup.py
plugins/flytekit-spark/flytekitplugins/spark/schema.py

Lines changed: 87 additions & 0 deletions
@@ -84,6 +84,84 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
        return r.all()


classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
ClassicDataFrame = classic_ps_dataframe.DataFrame


class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]):
    """
    Implements how Classic SparkDataFrame should be read using the ``open`` method of FlyteSchema
    """

    def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
        super().__init__(from_path, cols, fmt)

    def iter(self, **kwargs) -> typing.Generator[T, None, None]:
        raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks")

    def all(self, **kwargs) -> ClassicDataFrame:
        if self._fmt == SchemaFormat.PARQUET:
            ctx = FlyteContext.current_context().user_space_params
            return ctx.spark_session.read.parquet(self.from_path)
        raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")


class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]):
    """
    Implements how Classic SparkDataFrame should be written using ``open`` method of FlyteSchema
    """

    def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
        super().__init__(to_path, cols, fmt)

    def write(self, *dfs: ClassicDataFrame, **kwargs):
        if dfs is None or len(dfs) == 0:
            return
        if len(dfs) > 1:
            raise AssertionError("Only a single Classic Spark.DataFrame can be written per variable currently")
        if self._fmt == SchemaFormat.PARQUET:
            dfs[0].write.mode("overwrite").parquet(self.to_path)
            return
        raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")


class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]):
    """
    Transforms Classic Spark DataFrame's to and from a Schema (typed/untyped)
    """

    def __init__(self):
        super().__init__("classic-spark-df-transformer", t=ClassicDataFrame)

    @staticmethod
    def _get_schema_type() -> SchemaType:
        return SchemaType(columns=[])

    def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType:
        return LiteralType(schema=self._get_schema_type())

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: ClassicDataFrame,
        python_type: Type[ClassicDataFrame],
        expected: LiteralType,
    ) -> Literal:
        remote_path = ctx.file_access.join(
            ctx.file_access.raw_output_prefix,
            ctx.file_access.get_random_string(),
        )
        w = ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET)
        w.write(python_val)
        return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))

    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T:
        if not (lv and lv.scalar and lv.scalar.schema):
            return ClassicDataFrame()
        r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
        return r.all()


# %%
# Registers a handle for Spark DataFrame + Flyte Schema type transition
# This allows open(pyspark.DataFrame) to be an acceptable type

@@ -97,6 +175,15 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
    )
)

SchemaEngine.register_handler(
    SchemaHandler(
        "pyspark.sql.classic.DataFrame-Schema",
        ClassicDataFrame,
        ClassicSparkDataFrameSchemaReader,
        ClassicSparkDataFrameSchemaWriter,
        handles_remote_io=True,
    )
)
# %%
# This makes pyspark.DataFrame as a supported output/input type with flytekit.
TypeEngine.register(SparkDataFrameTransformer())
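
For context, the sketch below shows the kind of task that would exercise the schema reader, writer, and transformer added above. It is illustrative only and not part of this commit: the task name, Spark configuration, and sample data are made up, and it assumes a standard flytekitplugins-spark setup in which flytekit.current_context().spark_session is available inside a Spark task. On PySpark 4.x the DataFrame produced by that session is concretely a pyspark.sql.classic.dataframe.DataFrame, which is the type these new handlers target.

# Illustrative usage only; not part of this commit.
import flytekit
import pyspark.sql
from flytekit import task
from flytekitplugins.spark import Spark


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def make_df() -> pyspark.sql.DataFrame:
    # On PySpark 4.x this object is concretely a
    # pyspark.sql.classic.dataframe.DataFrame, the type handled by
    # ClassicSparkDataFrameTransformer and the schema handlers above.
    sess = flytekit.current_context().spark_session
    return sess.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])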

plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py

Lines changed: 49 additions & 0 deletions
@@ -72,3 +72,52 @@ def decode(
StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer())

classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
ClassicDataFrame = classic_ps_dataframe.DataFrame


class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder):
    def __init__(self):
        super().__init__(ClassicDataFrame, None, PARQUET)

    def encode(
        self,
        ctx: FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: StructuredDatasetType,
    ) -> literals.StructuredDataset:
        path = typing.cast(str, structured_dataset.uri)
        if not path:
            path = ctx.file_access.join(
                ctx.file_access.raw_output_prefix,
                ctx.file_access.get_random_string(),
            )
        df = typing.cast(ClassicDataFrame, structured_dataset.dataframe)
        ss = pyspark.sql.SparkSession.builder.getOrCreate()
        ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
        df.write.mode("overwrite").parquet(path=path)
        return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))


class ParquetToClassicSparkDecodingHandler(StructuredDatasetDecoder):
    def __init__(self):
        super().__init__(ClassicDataFrame, None, PARQUET)

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: StructuredDatasetMetadata,
    ) -> ClassicDataFrame:
        user_ctx = FlyteContext.current_context().user_space_params
        if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
            return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
        return user_ctx.spark_session.read.parquet(flyte_value.uri)


# Register the handlers
StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToClassicSparkDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(ClassicDataFrame, SparkDataFrameRenderer())
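
A similar hedged sketch for the StructuredDataset path, again not part of the commit and using made-up task names; it assumes PySpark 4.x (where the classic module exists) and a working flytekitplugins-spark deployment. The column annotation on the consuming task is what feeds the column list that the decoder uses for its select.

# Illustrative usage only; not part of this commit.
from typing import Annotated

import flytekit
from flytekit import kwtypes, task
from flytekit.types.structured import StructuredDataset
from flytekitplugins.spark import Spark


@task(task_config=Spark())
def produce() -> StructuredDataset:
    sess = flytekit.current_context().spark_session
    df = sess.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])
    # On PySpark 4.x df is a classic DataFrame, so the encoder registered
    # above (ClassicSparkToParquetEncodingHandler) writes it out as Parquet.
    return StructuredDataset(dataframe=df)


@task(task_config=Spark())
def consume(sd: Annotated[StructuredDataset, kwtypes(id=int)]) -> int:
    # The type passed to open() picks the decoder; asking for the classic
    # class routes through ParquetToClassicSparkDecodingHandler, which also
    # projects down to the "id" column named in the annotation.
    from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame

    df = sd.open(ClassicDataFrame).all()
    return df.count()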

plugins/flytekit-spark/setup.py

Lines changed: 1 addition & 2 deletions
@@ -4,8 +4,7 @@

 microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

-# TODO: Add support spark 4.0.0, https://github.com/flyteorg/flyte/issues/6478
-plugin_requires = ["flytekit>=1.15.1", "pyspark>=3.4.0,<4.0.0", "aiohttp", "flyteidl>=1.11.0b1", "pandas"]
+plugin_requires = ["flytekit>=1.15.1", "pyspark>=3.4.0", "aiohttp", "flyteidl>=1.11.0b1", "pandas"]

 __version__ = "0.0.0+develop"
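
Relaxing the upper bound on pyspark fits the rest of the commit: the pyspark.sql.classic.dataframe module targeted by the new handlers only ships with newer PySpark releases. Purely as an illustration, and not part of the commit, one way to check whether the installed PySpark exposes that module:

# Illustrative check only; not part of this commit.
import pyspark

try:
    # Present in the PySpark releases that split the classic and Spark Connect
    # DataFrame implementations; absent in older releases.
    from pyspark.sql.classic.dataframe import DataFrame  # noqa: F401

    has_classic = True
except ImportError:
    has_classic = False

print(pyspark.__version__, "classic DataFrame available:", has_classic)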
