33"""
44
55from abc import ABC
6- from typing import Any , Callable , Dict
6+ from typing import Any , Callable , Dict , List , Optional
77import datetime
88import types
99import numbers
2121# TODO: possibly import dask, Nvidia Rapids, modin, datatable versions
2222
2323
+ def none_mark_scalar_or_length(v) -> Optional[int]:
+     """
+     Return None if the item is a scalar, else the length of the object.
+
+     :param v: value to test
+     :return: None if value is a scalar, else length.
+     """
+     # catch the obvious scalar types up front, including str (as str doesn't throw on len())
+     if isinstance(v, (type(None), str, int, float)):
+         return None  # obvious scalar
+     # len() throws on scalars other than str
+     try:
+         return len(v)
+     except TypeError:
+         return None  # len() failed, probably a scalar
+
+
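A quick illustration of the new helper's contract (editor's sketch, not part of the commit):

    # scalars, including str (which len() would otherwise accept), map to None
    assert none_mark_scalar_or_length(None) is None
    assert none_mark_scalar_or_length("x") is None
    assert none_mark_scalar_or_length(7.5) is None
    # sized containers map to their length
    assert none_mark_scalar_or_length([1, 2, 3]) == 3
    assert none_mark_scalar_or_length(()) == 0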
+ def promote_scalar_to_array(vi, *, target_len: int) -> List:
+     """
+     Convert a scalar into a vector. Pass a non-trivial array through.
+
+     :param vi: value to promote to a vector
+     :param target_len: length for the vector
+     :return: list
+     """
+     assert isinstance(target_len, int)
+     assert target_len >= 0
+     if target_len <= 0:
+         return []
+     len_v = none_mark_scalar_or_length(vi)
+     if len_v is None:
+         return [vi] * target_len  # scalar
+     if len_v == target_len:
+         return vi
+     if len_v == 1:
+         return [vi[0]] * target_len  # TODO: see if we can eliminate this one
+     else:
+         raise ValueError("incompatible column lengths")
+
+
def _negate_or_subtract(*args):
    if len(args) == 1:
        return numpy.negative(args[0])
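A sketch of promote_scalar_to_array's intended behavior, under the same caveat (editor's example, checked against the definition above):

    assert promote_scalar_to_array(1, target_len=3) == [1, 1, 1]           # scalar replicated
    assert promote_scalar_to_array([1, 2, 3], target_len=3) == [1, 2, 3]   # matching length passes through
    assert promote_scalar_to_array([7], target_len=3) == [7, 7, 7]         # length-1 vector replicated
    # promote_scalar_to_array([1, 2], target_len=3) raises ValueError: incompatible column lengths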
@@ -436,56 +477,48 @@ def sql_proxy_step(self, op, *, data_map: dict, narrow: bool):
        res = db_handle.read_query("\n".join(op.sql))
        return res

-     def columns_to_frame_(self, cols, *, target_rows=0):
+     def columns_to_frame_(self, cols: Dict[str, Any], *, target_rows: Optional[int] = None):
440481 """
441482 Convert a dictionary of column names to series-like objects and scalars into a Pandas data frame.
483+ Deal with special cases, such as some columns coming in as scalars (often from Panda aggregation).
442484
443485 :param cols: dictionary mapping column names to columns
444486 :param target_rows: number of rows we are shooting for
445487 :return: Pandas data frame.
446488 """
        # noinspection PyUnresolvedReferences
        assert isinstance(cols, dict)
+         assert isinstance(target_rows, (int, type(None)))
+         if target_rows is not None:
+             assert target_rows >= 0
        if len(cols) < 1:
-             return self.pd.DataFrame(cols)
-         for k, v in cols.items():
-             try:
-                 target_rows = max(target_rows, len(v))
-             except TypeError:
-                 target_rows = max(target_rows, 1)  # scalar
+             # no columns, so nothing carrying index information
+             if target_rows is not None:
+                 return self.pd.DataFrame({}, index=range(target_rows)).reset_index(drop=True, inplace=False)
+             else:
+                 return self.pd.DataFrame({})
+         was_all_scalars = True
+         for v in cols.values():
+             ln = none_mark_scalar_or_length(v)
+             if ln is not None:
+                 was_all_scalars = False
+                 if target_rows is None:
+                     target_rows = ln
+                 else:
+                     assert target_rows == ln
+         if was_all_scalars:
+             if target_rows is None:
+                 target_rows = 1
+             # all scalars, so nothing carrying index information
+             promoted_cols = {k: promote_scalar_to_array(v, target_len=target_rows) for (k, v) in cols.items()}
+             return self.pd.DataFrame(promoted_cols, index=range(target_rows)).reset_index(drop=True, inplace=False)
+         assert target_rows is not None
        if target_rows < 1:
-             # noinspection PyBroadException
-             try:
-                 res = self.pd.DataFrame(cols)
-                 if res.shape[0] > 0:
-                     res = res.loc[[False] * res.shape[0], :].reset_index(
-                         drop=True, inplace=False
-                     )
-             except Exception:
-                 res = self.pd.DataFrame({k: [] for k in cols.keys()})
-             return res
-
+             # no rows, so presuming no index information (shouldn't have come from an aggregation)
+             return self.pd.DataFrame({k: [] for k in cols.keys()})
        # agg can return scalars, which then can't be made into a self.pd.DataFrame
-         def promote_scalar(vi, *, target_len):
-             """
-             Convert a scalar into a vector.
-             """
-             # noinspection PyBroadException
-             try:
-                 len_v = len(vi)
-                 if len_v != target_len:
-                     if len_v == 0:
-                         return [None] * target_len
-                     elif len_v == 1:
-                         return [vi[0]] * target_len
-                     else:
-                         raise ValueError("incompatible column lengths")
-             except Exception:
-                 return [vi] * target_len  # scalar
-             return vi
-
-         cols = {k: promote_scalar(v, target_len=target_rows) for (k, v) in cols.items()}
-         return self.pd.DataFrame(cols)
+         promoted_cols = {k: promote_scalar_to_array(v, target_len=target_rows) for (k, v) in cols.items()}
+         return self.pd.DataFrame(promoted_cols)
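For intuition, a minimal sketch of the mixed scalar/column case the rewritten method now handles, written against plain pandas rather than the model's self.pd handle (editor's example, reusing the promote_scalar_to_array helper above):

    import pandas as pd

    cols = {"g": ["a", "b"], "n": 2}  # one real column plus a scalar, as an aggregation can return
    target_rows = 2  # deduced from the non-scalar column
    promoted = {k: promote_scalar_to_array(v, target_len=target_rows) for (k, v) in cols.items()}
    frame = pd.DataFrame(promoted)  # two rows; the scalar 2 is replicated down the n column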

    def add_data_frame_columns_to_data_frame_(self, res, transient_new_frame):
        """
@@ -542,14 +575,22 @@ def extend_step(self, op, *, data_map, narrow):
542575 """
543576 if op .node_name != "ExtendNode" :
544577 raise TypeError ("op was supposed to be a data_algebra.data_ops.ExtendNode" )
578+ res = self ._eval_value_source (op .sources [0 ], data_map = data_map , narrow = narrow )
579+ if res .shape [0 ] <= 0 :
580+ # special case out no-row frame
581+ incoming_col_set = set (res .columns )
582+ v_dict = {k : [] for k in res .columns }
583+ for k in op .ops .keys ():
584+ if k not in incoming_col_set :
585+ v_dict [k ] = []
586+ return self .pd .DataFrame (v_dict )
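The guard above returns a zero-row frame carrying both the incoming columns and the newly derived column names; a small standalone sketch of that shape (editor's example; the column names x and y are hypothetical):

    import pandas as pd

    res = pd.DataFrame({"x": []})          # zero-row input frame
    v_dict = {k: [] for k in res.columns}  # keep the incoming columns
    v_dict["y"] = []                       # add a derived column name, as from op.ops
    out = pd.DataFrame(v_dict)             # zero rows, columns x and y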
        window_situation = (
            op.windowed_situation
            or (len(op.partition_by) > 0)
            or (len(op.order_by) > 0)
        )
        if window_situation:
            op.check_extend_window_fns_()
-         res = self._eval_value_source(op.sources[0], data_map=data_map, narrow=narrow)
        if not window_situation:
            with warnings.catch_warnings():
                warnings.simplefilter(
@@ -569,7 +610,6 @@ def extend_step(self, op, *, data_map, narrow):
                    col_list = col_list + [c]
                    col_set.add(c)
            order_cols = [c for c in col_list]  # must be partition by followed by order
-
        for (k, opk) in op.ops.items():
            # assumes all args are column names or values, enforce this earlier
            if len(opk.args) > 0:
@@ -751,10 +791,15 @@ def project_step(self, op, *, data_map, narrow):
        # agg can return scalars, which then can't be made into a self.pd.DataFrame
        res = self.columns_to_frame_(cols)
        res = res.reset_index(
-             drop=len(op.group_by) < 1
+             drop=(len(op.group_by) < 1) or (res.shape[0] <= 0)
        )  # grouping variables in the index
        missing_group_cols = set(op.group_by) - set(res.columns)
-         assert len(missing_group_cols) <= 0
+         if res.shape[0] > 0:
+             if len(missing_group_cols) != 0:
+                 raise ValueError("missing group columns")
+         else:
+             for g in missing_group_cols:
+                 res[g] = []
        if "_data_table_temp_col" in res.columns:
            res = res.drop("_data_table_temp_col", axis=1, inplace=False)
        # double check shape is what we expect
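Background on the reset_index change (editor's illustration, not the project's code): a grouped aggregation leaves the grouping keys in the index, and reset_index(drop=False) restores them as columns; on a zero-row result the keys may not come back at all, which is what the new else-branch patches by adding them as empty columns.

    import pandas as pd

    d = pd.DataFrame({"g": ["a", "a", "b"], "v": [1, 2, 3]})
    agg = d.groupby("g").agg({"v": "sum"})  # "g" is now the index, not a column
    agg = agg.reset_index(drop=False)       # "g" comes back as a column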
@@ -873,6 +918,9 @@ def natural_join_step(self, op, *, data_map, narrow):
        )
        left = self._eval_value_source(op.sources[0], data_map=data_map, narrow=narrow)
        right = self._eval_value_source(op.sources[1], data_map=data_map, narrow=narrow)
+         if (left.shape[0] == 0) and (right.shape[0] == 0):
+             # pandas seems to not like this case
+             return self.pd.DataFrame({k: [] for k in op.columns_produced()})
        common_cols = set([c for c in left.columns]).intersection(
            [c for c in right.columns]
        )
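A hedged sketch of what the empty-join guard produces; per the commit's own comment, pandas "seems to not like" merging two zero-row frames, so the result is built directly from the expected output columns (editor's example; the names below are hypothetical stand-ins for op.columns_produced()):

    import pandas as pd

    columns_produced = ["k", "x", "y"]  # stand-in for op.columns_produced()
    empty_join = pd.DataFrame({k: [] for k in columns_produced})
    # zero rows, but the full expected column set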