1616# under the License.
1717
1818import datafusion
19- import pyarrow
19+ import pyarrow as pa
2020import pyarrow .compute
2121from datafusion import Accumulator , col , udaf
2222
class MyAccumulator(Accumulator):
    """
    Interface of a user-defined accumulation.

    Keeps a running sum in ``self._sum`` as a pyarrow scalar. Both
    ``update`` (raw input values) and ``merge`` (partial states) fold the
    sum of the incoming array into that running total.
    """

    def __init__(self) -> None:
        # Seed with a Python float so the running total starts at 0.0.
        self._sum = pa.scalar(0.0)

    def _add_partial(self, values: pa.Array) -> None:
        """Add the sum of ``values`` to the running total.

        ``pa.compute.sum`` yields a null scalar when ``values`` has no
        non-null entries; ``as_py()`` then returns ``None``, which cannot be
        added to a float. Such batches contribute nothing and are skipped.
        """
        # pyarrow scalars can't be summed directly, so go through Python values.
        partial = pa.compute.sum(values).as_py()
        if partial is None:
            return
        self._sum = pa.scalar(self._sum.as_py() + partial)

    def update(self, values: pa.Array) -> None:
        """Accumulate a batch of input values into the running sum."""
        self._add_partial(values)

    def merge(self, states: pa.Array) -> None:
        """Accumulate an array of partial sums into the running sum."""
        self._add_partial(states)

    def state(self) -> pa.Array:
        """Return the running sum as a one-element array."""
        return pa.array([self._sum.as_py()])

    def evaluate(self) -> pa.Scalar:
        """Return the running sum as the final aggregate value."""
        return self._sum
4945
5046
# Session context used to build the DataFrame below.
ctx = datafusion.SessionContext()

# In-memory record batch with columns "a" and "b", wrapped in a DataFrame.
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array([4, 5, 6])], names=["a", "b"]
)
df = ctx.create_dataframe([[batch]])

# Wrap MyAccumulator as an aggregate function; the three type arguments and
# the volatility string are passed positionally, as udaf() receives them.
my_udaf = udaf(MyAccumulator, pa.float64(), pa.float64(), [pa.float64()], "stable")

# Aggregate column "a" over the whole frame (no grouping expressions).
df = df.aggregate([], [my_udaf(col("a"))])

# First collected batch holds the aggregate output.
result = df.collect()[0]

# 1 + 2 + 3 summed by the accumulator.
assert result.column(0) == pa.array([6.0])
0 commit comments