forked from flyteorg/flytekit
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsd_transformers.py
More file actions
123 lines (101 loc) · 4.86 KB
/
sd_transformers.py
File metadata and controls
123 lines (101 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import typing
from flytekit import FlyteContext, lazy_module
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
PARQUET,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetTransformerEngine,
)
# Heavy third-party dependencies are resolved lazily so this plugin module can
# be imported even when pandas/pyspark are not installed; the real import
# happens on first attribute access.
pd = lazy_module("pandas")
pyspark = lazy_module("pyspark")
ps_dataframe = lazy_module("pyspark.sql.dataframe")
# The Spark DataFrame class used as the registration key for the handlers below.
DataFrame = ps_dataframe.DataFrame
class SparkDataFrameRenderer:
    """Render the schema of a Spark DataFrame as an HTML table.

    Only the schema (list of ``StructField``) is rendered, not the data,
    so rendering stays cheap even for very large dataframes.
    """

    def to_html(self, df: DataFrame) -> str:
        assert isinstance(df, DataFrame)
        # Wrap the StructField sequence in a pandas frame purely to reuse
        # pandas' HTML table export.
        schema_frame = pd.DataFrame(df.schema, columns=["StructField"])
        return schema_frame.to_html()
class SparkToParquetEncodingHandler(StructuredDatasetEncoder):
    """Encode a Spark DataFrame into a Parquet-backed structured dataset."""

    def __init__(self):
        # Handle pyspark DataFrames for any storage protocol (None), PARQUET format.
        super().__init__(DataFrame, None, PARQUET)

    def encode(
        self,
        ctx: FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: StructuredDatasetType,
    ) -> literals.StructuredDataset:
        """Write the dataframe as Parquet and return the literal pointing at it.

        If the user did not supply a destination URI, a random path under the
        context's raw output prefix is generated.
        """
        uri = typing.cast(str, structured_dataset.uri)
        if not uri:
            uri = ctx.file_access.join(
                ctx.file_access.raw_output_prefix,
                ctx.file_access.get_random_string(),
            )
        frame = typing.cast(DataFrame, structured_dataset.dataframe)
        session = pyspark.sql.SparkSession.builder.getOrCreate()
        # Avoid generating SUCCESS files
        session.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
        frame.write.mode("overwrite").parquet(path=uri)
        return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type))
class ParquetToSparkDecodingHandler(StructuredDatasetDecoder):
    """Decode a Parquet-backed structured dataset into a Spark DataFrame."""

    def __init__(self):
        # Handle pyspark DataFrames for any storage protocol (None), PARQUET format.
        super().__init__(DataFrame, None, PARQUET)

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: StructuredDatasetMetadata,
    ) -> DataFrame:
        """Read the Parquet data with the task's Spark session, projecting to
        the declared columns when the consuming task specifies any."""
        user_ctx = FlyteContext.current_context().user_space_params
        # Spark reads lazily, so constructing the reader up front is free.
        frame = user_ctx.spark_session.read.parquet(flyte_value.uri)
        sd_type = current_task_metadata.structured_dataset_type
        if sd_type and sd_type.columns:
            return frame.select(*[col.name for col in sd_type.columns])
        return frame
# Register the handlers so the transformer engine can round-trip Spark
# DataFrames through Parquet, plus the schema renderer for deck/UI output.
StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer())
# NOTE(review): pyspark.sql.classic.dataframe appears to exist only in newer
# pyspark releases — lazy_module defers the import, presumably so older
# installs don't fail at module load time; confirm against supported versions.
classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
ClassicDataFrame = classic_ps_dataframe.DataFrame
class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder):
    """Encode a classic Spark DataFrame into a Parquet-backed structured dataset."""

    def __init__(self):
        # Handle classic pyspark DataFrames for any protocol (None), PARQUET format.
        super().__init__(ClassicDataFrame, None, PARQUET)

    def encode(
        self,
        ctx: FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: StructuredDatasetType,
    ) -> literals.StructuredDataset:
        """Write the dataframe as Parquet and return the literal pointing at it.

        If the user did not supply a destination URI, a random path under the
        context's raw output prefix is generated.
        """
        uri = typing.cast(str, structured_dataset.uri)
        if not uri:
            uri = ctx.file_access.join(
                ctx.file_access.raw_output_prefix,
                ctx.file_access.get_random_string(),
            )
        frame = typing.cast(ClassicDataFrame, structured_dataset.dataframe)
        session = pyspark.sql.SparkSession.builder.getOrCreate()
        # Avoid generating SUCCESS files
        session.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
        frame.write.mode("overwrite").parquet(path=uri)
        return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type))
class ParquetToClassicSparkDecodingHandler(StructuredDatasetDecoder):
    """Decode a Parquet-backed structured dataset into a classic Spark DataFrame."""

    def __init__(self):
        # Handle classic pyspark DataFrames for any protocol (None), PARQUET format.
        super().__init__(ClassicDataFrame, None, PARQUET)

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: StructuredDatasetMetadata,
    ) -> ClassicDataFrame:
        """Read the Parquet data with the task's Spark session, projecting to
        the declared columns when the consuming task specifies any."""
        user_ctx = FlyteContext.current_context().user_space_params
        # Spark reads lazily, so constructing the reader up front is free.
        frame = user_ctx.spark_session.read.parquet(flyte_value.uri)
        sd_type = current_task_metadata.structured_dataset_type
        if sd_type and sd_type.columns:
            return frame.select(*[col.name for col in sd_type.columns])
        return frame
# Register the handlers for the classic DataFrame class as well, so both
# Spark DataFrame flavors are supported by the transformer engine.
StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToClassicSparkDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(ClassicDataFrame, SparkDataFrameRenderer())