8
8
TYPE_CHECKING ,
9
9
Any ,
10
10
Dict ,
11
+ List ,
11
12
Literal ,
12
13
Optional ,
13
14
Protocol ,
23
24
import numpy as np
24
25
25
26
from ._typing import CNumericPtr , DataType , NumpyDType , NumpyOrCupy
26
- from .compat import import_cupy , lazy_isinstance
27
+ from .compat import import_cupy , import_pyarrow , lazy_isinstance
27
28
28
29
if TYPE_CHECKING :
29
30
import pandas as pd
@@ -69,7 +70,11 @@ def shape(self) -> Tuple[int, int]:
69
70
70
71
def array_hasobject (data : DataType ) -> bool :
71
72
"""Whether the numpy array has object dtype."""
72
- return hasattr (data .dtype , "hasobject" ) and data .dtype .hasobject
73
+ return (
74
+ hasattr (data , "dtype" )
75
+ and hasattr (data .dtype , "hasobject" )
76
+ and data .dtype .hasobject
77
+ )
73
78
74
79
75
80
def cuda_array_interface_dict (data : _CudaArrayLikeArg ) -> ArrayInf :
@@ -180,7 +185,7 @@ def is_arrow_dict(data: Any) -> TypeGuard["pa.DictionaryArray"]:
180
185
return lazy_isinstance (data , "pyarrow.lib" , "DictionaryArray" )
181
186
182
187
183
- class PdCatAccessor (Protocol ):
188
+ class DfCatAccessor (Protocol ):
184
189
"""Protocol for pandas cat accessor."""
185
190
186
191
@property
@@ -202,7 +207,7 @@ def to_arrow( # pylint: disable=missing-function-docstring
202
207
def __cuda_array_interface__ (self ) -> ArrayInf : ...
203
208
204
209
205
- def _is_pd_cat (data : Any ) -> TypeGuard [PdCatAccessor ]:
210
+ def _is_df_cat (data : Any ) -> TypeGuard [DfCatAccessor ]:
206
211
# Test pd.Series.cat, not pd.Series
207
212
return hasattr (data , "categories" ) and hasattr (data , "codes" )
208
213
@@ -234,6 +239,67 @@ def npstr_to_arrow_strarr(strarr: np.ndarray) -> Tuple[np.ndarray, str]:
234
239
return offsets .astype (np .int32 ), values
235
240
236
241
242
+ def _arrow_cat_inf ( # pylint: disable=too-many-locals
243
+ cats : "pa.StringArray" ,
244
+ codes : Union [_ArrayLikeArg , _CudaArrayLikeArg , "pa.IntegerArray" ],
245
+ ) -> Tuple [StringArray , ArrayInf , Tuple ]:
246
+ if not TYPE_CHECKING :
247
+ pa = import_pyarrow ()
248
+
249
+ # FIXME(jiamingy): Account for offset, need to find an implementation that returns
250
+ # offset > 0
251
+ assert cats .offset == 0
252
+ buffers : List [pa .Buffer ] = cats .buffers ()
253
+ mask , offset , data = buffers
254
+ assert offset .is_cpu
255
+
256
+ off_len = len (cats ) + 1
257
+ if offset .size != off_len * (np .iinfo (np .int32 ).bits / 8 ):
258
+ raise TypeError ("Arrow dictionary type offsets is required to be 32 bit." )
259
+
260
+ joffset : ArrayInf = {
261
+ "data" : (offset .address , True ),
262
+ "typestr" : "<i4" ,
263
+ "version" : 3 ,
264
+ "strides" : None ,
265
+ "shape" : (off_len ,),
266
+ "mask" : None ,
267
+ }
268
+
269
+ def make_buf_inf (buf : pa .Buffer , typestr : str ) -> ArrayInf :
270
+ return {
271
+ "data" : (buf .address , True ),
272
+ "typestr" : typestr ,
273
+ "version" : 3 ,
274
+ "strides" : None ,
275
+ "shape" : (buf .size ,),
276
+ "mask" : None ,
277
+ }
278
+
279
+ jdata = make_buf_inf (data , "<i1" )
280
+ # Categories should not have missing values.
281
+ assert mask is None
282
+
283
+ jnames : StringArray = {"offsets" : joffset , "values" : jdata }
284
+
285
+ def make_array_inf (
286
+ array : Any ,
287
+ ) -> Tuple [ArrayInf , Optional [Tuple [pa .Buffer , pa .Buffer ]]]:
288
+ """Helper for handling categorical codes."""
289
+ # Handle cuDF data
290
+ if hasattr (array , "__cuda_array_interface__" ):
291
+ inf = cuda_array_interface_dict (array )
292
+ return inf , None
293
+
294
+ # Other types (like arrow itself) are not yet supported.
295
+ raise TypeError ("Invalid input type." )
296
+
297
+ cats_tmp = (mask , offset , data )
298
+ jcodes , codes_tmp = make_array_inf (codes )
299
+
300
+ return jnames , jcodes , (cats_tmp , codes_tmp )
301
+
302
+
237
303
def _ensure_np_dtype (
238
304
data : DataType , dtype : Optional [NumpyDType ]
239
305
) -> Tuple [np .ndarray , Optional [NumpyDType ]]:
@@ -252,7 +318,7 @@ def array_interface_dict(data: np.ndarray) -> ArrayInf: ...
252
318
253
319
@overload
254
320
def array_interface_dict (
255
- data : PdCatAccessor ,
321
+ data : DfCatAccessor ,
256
322
) -> Tuple [StringArray , ArrayInf , Tuple ]: ...
257
323
258
324
@@ -263,11 +329,11 @@ def array_interface_dict(
263
329
264
330
265
331
def array_interface_dict ( # pylint: disable=too-many-locals
266
- data : Union [np .ndarray , PdCatAccessor ],
332
+ data : Union [np .ndarray , DfCatAccessor ],
267
333
) -> Union [ArrayInf , Tuple [StringArray , ArrayInf , Optional [Tuple ]]]:
268
334
"""Returns an array interface from the input."""
269
335
# Handle categorical values
270
- if _is_pd_cat (data ):
336
+ if _is_df_cat (data ):
271
337
cats = data .categories
272
338
# pandas uses -1 to represent missing values for categorical features
273
339
codes = data .codes .replace (- 1 , np .nan )
@@ -287,6 +353,7 @@ def array_interface_dict( # pylint: disable=too-many-locals
287
353
name_offsets , _ = _ensure_np_dtype (name_offsets , np .int32 )
288
354
joffsets = array_interface_dict (name_offsets )
289
355
bvalues = name_values .encode ("utf-8" )
356
+
290
357
ptr = ctypes .c_void_p .from_buffer (ctypes .c_char_p (bvalues )).value
291
358
assert ptr is not None
292
359
@@ -335,3 +402,20 @@ def check_cudf_meta(data: _CudaArrayLikeArg, field: str) -> None:
335
402
and data .__cuda_array_interface__ ["mask" ] is not None
336
403
):
337
404
raise ValueError (f"Missing value is not allowed for: { field } " )
405
+
406
+
407
+ def cudf_cat_inf (
408
+ cats : DfCatAccessor , codes : "pd.Series"
409
+ ) -> Tuple [Union [ArrayInf , StringArray ], ArrayInf , Tuple ]:
410
+ """Obtain the cuda array interface for cuDF categories."""
411
+ cp = import_cupy ()
412
+ is_num_idx = cp .issubdtype (cats .dtype , cp .floating ) or cp .issubdtype (
413
+ cats .dtype , cp .integer
414
+ )
415
+ if is_num_idx :
416
+ cats_ainf = cats .__cuda_array_interface__
417
+ codes_ainf = cuda_array_interface_dict (codes )
418
+ return cats_ainf , codes_ainf , (cats , codes )
419
+
420
+ joffset , jdata , buf = _arrow_cat_inf (cats .to_arrow (), codes )
421
+ return joffset , jdata , buf
0 commit comments