@@ -84,6 +84,84 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
         return r.all()
 
 
+classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
+ClassicDataFrame = classic_ps_dataframe.DataFrame
+
+
+class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]):
+    """
+    Implements how a classic Spark DataFrame should be read using the ``open`` method of FlyteSchema.
+    """
+
+    def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
+        super().__init__(from_path, cols, fmt)
+
+    def iter(self, **kwargs) -> typing.Generator[T, None, None]:
+        raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks")
+
+    def all(self, **kwargs) -> ClassicDataFrame:
+        if self._fmt == SchemaFormat.PARQUET:
+            ctx = FlyteContext.current_context().user_space_params
+            return ctx.spark_session.read.parquet(self.from_path)
+        raise AssertionError("Only Parquet files are currently supported for classic Spark DataFrames")
+
+
+class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]):
+    """
+    Implements how a classic Spark DataFrame should be written using the ``open`` method of FlyteSchema.
+    """
+
+    def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
+        super().__init__(to_path, cols, fmt)
+
+    def write(self, *dfs: ClassicDataFrame, **kwargs):
+        if dfs is None or len(dfs) == 0:
+            return
+        if len(dfs) > 1:
+            raise AssertionError("Only a single classic Spark DataFrame can currently be written per variable")
+        if self._fmt == SchemaFormat.PARQUET:
+            dfs[0].write.mode("overwrite").parquet(self.to_path)
+            return
+        raise AssertionError("Only Parquet files are currently supported for classic Spark DataFrames")
+
+
+class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]):
+    """
+    Transforms classic Spark DataFrames to and from a (typed/untyped) Schema.
+    """
+
+    def __init__(self):
+        super().__init__("classic-spark-df-transformer", t=ClassicDataFrame)
+
+    @staticmethod
+    def _get_schema_type() -> SchemaType:
+        return SchemaType(columns=[])
+
+    def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType:
+        return LiteralType(schema=self._get_schema_type())
+
+    def to_literal(
+        self,
+        ctx: FlyteContext,
+        python_val: ClassicDataFrame,
+        python_type: Type[ClassicDataFrame],
+        expected: LiteralType,
+    ) -> Literal:
+        remote_path = ctx.file_access.join(
+            ctx.file_access.raw_output_prefix,
+            ctx.file_access.get_random_string(),
+        )
+        w = ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET)
+        w.write(python_val)
+        return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))
+
+    def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T:
+        if not (lv and lv.scalar and lv.scalar.schema):
+            return ClassicDataFrame()
+        r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
+        return r.all()
+
+
 # %%
 # Registers a handle for Spark DataFrame + Flyte Schema type transition
 # This allows open(pyspark.DataFrame) to be an acceptable type
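
A note on what the hunk above adds: the writer and reader delegate all I/O to Spark itself. The writer dumps the DataFrame to a parquet directory at `to_path`, and the reader loads it back through the `spark_session` attached to the Flyte execution context, which is also exactly the round trip the transformer performs in `to_literal`/`to_python_value`. A rough sketch of that round trip, with hypothetical names (`df`, the `/tmp/classic_demo` path); it assumes it runs inside a Spark task, where the plugin populates `user_space_params.spark_session`:

# Hypothetical round trip mirroring what the new reader/writer pair does.
# `spark` and the local path are illustrative, not part of the diff.
df = spark.createDataFrame([("a", 1), ("b", 2)], ["key", "value"])

w = ClassicSparkDataFrameSchemaWriter(to_path="/tmp/classic_demo", cols=None, fmt=SchemaFormat.PARQUET)
w.write(df)  # executes df.write.mode("overwrite").parquet("/tmp/classic_demo")

r = ClassicSparkDataFrameSchemaReader(from_path="/tmp/classic_demo", cols=None, fmt=SchemaFormat.PARQUET)
df2 = r.all()  # executes spark_session.read.parquet("/tmp/classic_demo")
assert df2.count() == df.count()
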
@@ -97,6 +175,15 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
     )
 )
 
+SchemaEngine.register_handler(
+    SchemaHandler(
+        "pyspark.sql.classic.DataFrame-Schema",
+        ClassicDataFrame,
+        ClassicSparkDataFrameSchemaReader,
+        ClassicSparkDataFrameSchemaWriter,
+        handles_remote_io=True,
+    )
+)
 # %%
 # This makes pyspark.DataFrame as a supported output/input type with flytekit.
 TypeEngine.register(SparkDataFrameTransformer())
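
With the schema handler registered above, and assuming the change also registers the new transformer the same way (e.g. `TypeEngine.register(ClassicSparkDataFrameTransformer())`, which falls outside this hunk), a classic DataFrame can be passed between Spark tasks. A minimal usage sketch, assuming flytekitplugins-spark is installed; the task bodies and `spark_conf` values are illustrative only:

import flytekit
from flytekit import task, workflow
from flytekitplugins.spark import Spark
from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1000M"}))
def make_df() -> ClassicDataFrame:
    # The Spark plugin attaches spark_session to the execution context.
    sess = flytekit.current_context().spark_session
    return sess.createDataFrame([("Alice", 5), ("Bob", 10)], ["name", "age"])


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1000M"}))
def total_age(df: ClassicDataFrame) -> int:
    # The transformer rehydrates df from the parquet written upstream.
    return df.agg({"age": "sum"}).collect()[0][0]


@workflow
def wf() -> int:
    return total_age(df=make_df())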