@@ -245,92 +245,13 @@ def __repr__(self):
245
245
return "ArrowStreamSerializer"
246
246
247
247
248
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
    """
    Serializes Pandas.Series as Arrow data with Arrow streaming format.

    :param timezone: A timezone to respect when handling timestamp values
    :param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
    :param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
    """
257
def __init__ (self , timezone , safecheck , assign_cols_by_name ):
@@ -347,39 +268,138 @@ def arrow_to_pandas(self, arrow_column):
347
268
s = _check_series_localize_timestamps (s , self ._timezone )
348
269
return s
349
270
271
    def _create_batch(self, series):
        """
        Create an Arrow record batch from the given pandas.Series or list of Series,
        with optional type.

        :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
        :return: Arrow RecordBatch
        """
        import decimal
        from distutils.version import LooseVersion
        import pandas as pd
        import pyarrow as pa
        from pyspark.sql.types import _check_series_convert_timestamps_internal
        # Make input conform to [(series1, type1), (series2, type2), ...]
        if not isinstance(series, (list, tuple)) or \
                (len(series) == 2 and isinstance(series[1], pa.DataType)):
            series = [series]
        # Pair each bare series with a None type so the loop below is uniform.
        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

        def create_array(s, t):
            # Convert one pandas.Series to a pyarrow.Array, coercing to type t if given.
            mask = s.isnull()
            # Ensure timestamp series are in expected form for Spark internal representation
            # TODO: maybe don't need None check anymore as of Arrow 0.9.1
            if t is not None and pa.types.is_timestamp(t):
                s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone)
                # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
                return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
            elif t is not None and pa.types.is_string(t) and sys.version < '3':
                # TODO: need decode before converting to Arrow in Python 2
                # TODO: don't need as of Arrow 0.9.1
                return pa.Array.from_pandas(s.apply(
                    lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
            elif t is not None and pa.types.is_decimal(t) and \
                    LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
                # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
                return pa.Array.from_pandas(s.apply(
                    lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
            elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
                # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
                return pa.Array.from_pandas(s, mask=mask, type=t)

            try:
                # safe= honors spark.sql.execution.pandas.arrowSafeTypeConversion
                array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
            except pa.ArrowException as e:
                error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
                    "Array (%s). It can be caused by overflows or other unsafe " + \
                    "conversions warned by Arrow. Arrow safe type check can be " + \
                    "disabled by using SQL config " + \
                    "`spark.sql.execution.pandas.arrowSafeTypeConversion`."
                raise RuntimeError(error_msg % (s.dtype, t), e)
            return array

        arrs = []
        for s, t in series:
            if t is not None and pa.types.is_struct(t):
                if not isinstance(s, pd.DataFrame):
                    raise ValueError("A field of type StructType expects a pandas.DataFrame, "
                                     "but got: %s" % str(type(s)))

                # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
                if len(s) == 0 and len(s.columns) == 0:
                    arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
                # Assign result columns by schema name if user labeled with strings
                elif self._assign_cols_by_name and any(isinstance(name, basestring)
                                                       for name in s.columns):
                    arrs_names = [(create_array(s[field.name], field.type), field.name)
                                  for field in t]
                # Assign result columns by position
                else:
                    arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
                                  for i, field in enumerate(t)]

                struct_arrs, struct_names = zip(*arrs_names)

                # TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version
                if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
                    arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
                else:
                    arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
            else:
                arrs.append(create_array(s, t))

        # Columns get positional placeholder names "_0", "_1", ...
        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
350
355
def dump_stream (self , iterator , stream ):
351
356
"""
352
357
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
353
358
a list of series accompanied by an optional pyarrow type to coerce the data to.
354
359
"""
355
- import pyarrow as pa
356
- writer = None
357
- try :
358
- for series in iterator :
359
- batch = _create_batch (series , self ._timezone , self ._safecheck ,
360
- self ._assign_cols_by_name )
361
- if writer is None :
362
- write_int (SpecialLengths .START_ARROW_STREAM , stream )
363
- writer = pa .RecordBatchStreamWriter (stream , batch .schema )
364
- writer .write_batch (batch )
365
- finally :
366
- if writer is not None :
367
- writer .close ()
360
+ batches = (self ._create_batch (series ) for series in iterator )
361
+ super (ArrowStreamPandasSerializer , self ).dump_stream (batches , stream )
368
362
369
363
def load_stream (self , stream ):
370
364
"""
371
365
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
372
366
"""
367
+ batches = super (ArrowStreamPandasSerializer , self ).load_stream (stream )
373
368
import pyarrow as pa
374
- reader = pa .ipc .open_stream (stream )
375
-
376
- for batch in reader :
369
+ for batch in batches :
377
370
yield [self .arrow_to_pandas (c ) for c in pa .Table .from_batches ([batch ]).itercolumns ()]
378
371
379
372
    def __repr__(self):
        # Stable serializer name shown in logs and debugging output.
        return "ArrowStreamPandasSerializer"
381
374
382
375
376
class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
    """
    Serializer used by Python worker to evaluate Pandas UDFs
    """

    def dump_stream(self, iterator, stream):
        """
        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
        This should be sent after creating the first record batch so in case of an error, it can
        be sent back to the JVM before the Arrow stream starts.
        """

        def init_stream_yield_batches():
            # Write the START_ARROW_STREAM marker only once, and only after the
            # first batch was built successfully (laziness matters here: building
            # the batch may raise, and the error must reach the JVM first).
            should_write_start_length = True
            for series in iterator:
                batch = self._create_batch(series)
                if should_write_start_length:
                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
                    should_write_start_length = False
                yield batch

        # Bypass the immediate parent and use the plain Arrow stream writer.
        return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)

    def __repr__(self):
        # Stable serializer name shown in logs and debugging output.
        return "ArrowStreamPandasUDFSerializer"
383
403
class BatchedSerializer (Serializer ):
384
404
385
405
"""
0 commit comments