Skip to content

Commit 54d1d72

Browse files
authored
[backport] Use array interface for testing numpy arrays. (dmlc#9602) (dmlc#9635)
1 parent 032bcc5 commit 54d1d72

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

python-package/xgboost/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2399,6 +2399,7 @@ def inplace_predict(
23992399
_is_cudf_df,
24002400
_is_cupy_array,
24012401
_is_list,
2402+
_is_np_array_like,
24022403
_is_pandas_df,
24032404
_is_pandas_series,
24042405
_is_tuple,
@@ -2428,7 +2429,7 @@ def inplace_predict(
24282429
f"got {data.shape[1]}"
24292430
)
24302431

2431-
if isinstance(data, np.ndarray):
2432+
if _is_np_array_like(data):
24322433
from .data import _ensure_np_dtype
24332434

24342435
data, _ = _ensure_np_dtype(data, data.dtype)

python-package/xgboost/data.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ def _is_scipy_coo(data: DataType) -> bool:
164164
return isinstance(data, scipy.sparse.coo_matrix)
165165

166166

167-
def _is_numpy_array(data: DataType) -> bool:
168-
return isinstance(data, (np.ndarray, np.matrix))
167+
def _is_np_array_like(data: DataType) -> bool:
168+
return hasattr(data, "__array_interface__")
169169

170170

171171
def _ensure_np_dtype(
@@ -1051,7 +1051,7 @@ def dispatch_data_backend(
10511051
return _from_scipy_csr(
10521052
data.tocsr(), missing, threads, feature_names, feature_types
10531053
)
1054-
if _is_numpy_array(data):
1054+
if _is_np_array_like(data):
10551055
return _from_numpy_array(
10561056
data, missing, threads, feature_names, feature_types, data_split_mode
10571057
)
@@ -1194,7 +1194,7 @@ def dispatch_meta_backend(
11941194
if _is_tuple(data):
11951195
_meta_from_tuple(data, name, dtype, handle)
11961196
return
1197-
if _is_numpy_array(data):
1197+
if _is_np_array_like(data):
11981198
_meta_from_numpy(data, name, dtype, handle)
11991199
return
12001200
if _is_pandas_df(data):
@@ -1281,7 +1281,7 @@ def _proxy_transform(
12811281
return _transform_dlpack(data), None, feature_names, feature_types
12821282
if _is_list(data) or _is_tuple(data):
12831283
data = np.array(data)
1284-
if _is_numpy_array(data):
1284+
if _is_np_array_like(data):
12851285
data, _ = _ensure_np_dtype(data, data.dtype)
12861286
return data, None, feature_names, feature_types
12871287
if _is_scipy_csr(data):
@@ -1331,7 +1331,7 @@ def dispatch_proxy_set_data(
13311331
if not allow_host:
13321332
raise err
13331333

1334-
if _is_numpy_array(data):
1334+
if _is_np_array_like(data):
13351335
_check_data_shape(data)
13361336
proxy._set_data_from_array(data) # pylint: disable=W0212
13371337
return

0 commit comments

Comments
 (0)