1616# under the License. 
1717
1818import  datafusion 
19- import  pyarrow 
19+ import  pyarrow   as   pa 
2020import  pyarrow .compute 
2121from  datafusion  import  Accumulator , col , udaf 
2222
@@ -26,48 +26,44 @@ class MyAccumulator(Accumulator):
2626    Interface of a user-defined accumulation. 
2727    """ 
2828
29-     def  __init__ (self ):
30-         self ._sum  =  pyarrow .scalar (0.0 )
29+     def  __init__ (self )  ->   None :
30+         self ._sum  =  pa .scalar (0.0 )
3131
32-     def  update (self , values : pyarrow .Array ) ->  None :
32+     def  update (self , values : pa .Array ) ->  None :
3333        # not nice since pyarrow scalars can't be summed yet. This breaks on `None` 
34-         self ._sum  =  pyarrow .scalar (
35-             self ._sum .as_py () +  pyarrow .compute .sum (values ).as_py ()
36-         )
34+         self ._sum  =  pa .scalar (self ._sum .as_py () +  pa .compute .sum (values ).as_py ())
3735
38-     def  merge (self , states : pyarrow .Array ) ->  None :
36+     def  merge (self , states : pa .Array ) ->  None :
3937        # not nice since pyarrow scalars can't be summed yet. This breaks on `None` 
40-         self ._sum  =  pyarrow .scalar (
41-             self ._sum .as_py () +  pyarrow .compute .sum (states ).as_py ()
42-         )
38+         self ._sum  =  pa .scalar (self ._sum .as_py () +  pa .compute .sum (states ).as_py ())
4339
44-     def  state (self ) ->  pyarrow .Array :
45-         return  pyarrow .array ([self ._sum .as_py ()])
40+     def  state (self ) ->  pa .Array :
41+         return  pa .array ([self ._sum .as_py ()])
4642
47-     def  evaluate (self ) ->  pyarrow .Scalar :
43+     def  evaluate (self ) ->  pa .Scalar :
4844        return  self ._sum 
4945
5046
5147# create a context 
5248ctx  =  datafusion .SessionContext ()
5349
5450# create a RecordBatch and a new DataFrame from it 
55- batch  =  pyarrow .RecordBatch .from_arrays (
56-     [pyarrow .array ([1 , 2 , 3 ]), pyarrow .array ([4 , 5 , 6 ])],
51+ batch  =  pa .RecordBatch .from_arrays (
52+     [pa .array ([1 , 2 , 3 ]), pa .array ([4 , 5 , 6 ])],
5753    names = ["a" , "b" ],
5854)
5955df  =  ctx .create_dataframe ([[batch ]])
6056
6157my_udaf  =  udaf (
6258    MyAccumulator ,
63-     pyarrow .float64 (),
64-     pyarrow .float64 (),
65-     [pyarrow .float64 ()],
59+     pa .float64 (),
60+     pa .float64 (),
61+     [pa .float64 ()],
6662    "stable" ,
6763)
6864
6965df  =  df .aggregate ([], [my_udaf (col ("a" ))])
7066
7167result  =  df .collect ()[0 ]
7268
73- assert  result .column (0 ) ==  pyarrow .array ([6.0 ])
69+ assert  result .column (0 ) ==  pa .array ([6.0 ])
0 commit comments