@@ -41,61 +41,53 @@ def __init__(
4141 self ._dropna = dropna
4242 self ._use_arrow_dtype = use_arrow_dtype
4343
def _drop_duplicates(self, xdf, value, explode=False):
    """Collapse ``value`` to its distinct entries, packed as one list cell.

    Args:
        xdf: Dataframe module currently in use; it is compared against
            ``cudf`` to decide whether the Arrow path may be taken.
        value: Series-like object providing ``explode`` and
            ``drop_duplicates``.
        explode: When true, ``value.explode()`` is applied before
            de-duplicating.

    Returns:
        A one-element Python list whose only item is the list of distinct
        values. When ``self._use_arrow_dtype`` is set and ``xdf`` is not
        ``cudf``, an ``ArrowListArray`` built from the distinct values is
        returned instead, unless pyarrow rejects them.
    """
    if explode:
        value = value.explode()

    # De-duplicate exactly once; the Arrow fallback below reuses the result
    # instead of recomputing it.
    distinct = value.drop_duplicates()

    if not self._use_arrow_dtype or xdf is cudf:
        return [distinct.to_list()]

    try:
        return ArrowListArray([distinct.to_numpy()])
    except pa.ArrowInvalid:
        # fallback due to diverse dtypes
        return [distinct.to_list()]
5356
def pre(self, in_data):  # noqa: W0221  # pylint: disable=arguments-differ
    """Reduce ``in_data`` to its distinct values.

    A Series input yields a Series (same name) built from the result of
    ``_drop_duplicates``. For any other input, ``self._axis == 0`` produces
    a DataFrame with one de-duplicated cell per column; otherwise a
    DataFrame with a single column ``0`` is filled row by row.
    """
    xdf = cudf if self.is_gpu() else pd

    if isinstance(in_data, xdf.Series):
        return xdf.Series(self._drop_duplicates(xdf, in_data), name=in_data.name)

    if self._axis == 0:
        # NOTE(review): DataFrame.iteritems() was removed in pandas 2.0;
        # confirm the supported pandas/cudf versions before moving to items().
        per_column = {
            label: self._drop_duplicates(xdf, column)
            for label, column in in_data.iteritems()
        }
        return xdf.DataFrame(per_column)

    result = xdf.DataFrame(columns=[0])
    for label, row in in_data.iterrows():
        result.loc[label] = self._drop_duplicates(xdf, row)
    return result
7673
def agg(self, in_data):  # noqa: W0221  # pylint: disable=arguments-differ
    """Merge partial results by exploding them and de-duplicating again.

    Mirrors ``pre`` but always calls ``_drop_duplicates`` with
    ``explode=True``. A Series input yields a Series with the same name;
    otherwise the result is a DataFrame built per column
    (``self._axis == 0``) or per row into a single column ``0``.
    """
    xdf = cudf if self.is_gpu() else pd

    if isinstance(in_data, xdf.Series):
        merged = self._drop_duplicates(xdf, in_data, explode=True)
        return xdf.Series(merged, name=in_data.name)

    if self._axis != 0:
        result = xdf.DataFrame(columns=[0])
        for label, row in in_data.iterrows():
            result.loc[label] = self._drop_duplicates(xdf, row, explode=True)
        return result

    per_column = {}
    # NOTE(review): DataFrame.iteritems() was removed in pandas 2.0;
    # confirm the supported pandas/cudf versions before moving to items().
    for label, column in in_data.iteritems():
        if self._use_arrow_dtype and xdf is not cudf:
            # Rebuild the column from its numpy values first; presumably so
            # explode sees plain Python lists rather than an Arrow-backed
            # dtype -- TODO confirm.
            column = pd.Series(column.to_numpy())
        per_column[label] = self._drop_duplicates(xdf, column, explode=True)
    return xdf.DataFrame(per_column)
10092
10193 def post (self , in_data ): # noqa: W0221 # pylint: disable=arguments-differ
0 commit comments