|
16 | 16 | import pandas as pd
|
17 | 17 | from numpy import all as array_all # noqa
|
18 | 18 | from numpy import any as array_any # noqa
|
| 19 | +from numpy import around # noqa |
19 | 20 | from numpy import zeros_like # noqa
|
20 |
| -from numpy import around, broadcast_to # noqa |
21 | 21 | from numpy import concatenate as _concatenate
|
22 | 22 | from numpy import ( # noqa
|
23 | 23 | einsum,
|
@@ -207,6 +207,11 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
|
207 | 207 | return [astype(x, out_type, copy=False) for x in arrays]
|
208 | 208 |
|
209 | 209 |
|
| 210 | +def broadcast_to(array, shape): |
| 211 | + xp = get_array_namespace(array) |
| 212 | + return xp.broadcast_to(array, shape) |
| 213 | + |
| 214 | + |
210 | 215 | def lazy_array_equiv(arr1, arr2):
|
211 | 216 | """Like array_equal, but doesn't actually compare values.
|
212 | 217 | Returns True when arr1, arr2 identical or their dask tokens are equal.
|
@@ -311,6 +316,9 @@ def fillna(data, other):
|
311 | 316 |
|
312 | 317 | def concatenate(arrays, axis=0):
|
313 | 318 | """concatenate() with better dtype promotion rules."""
|
| 319 | + if hasattr(arrays[0], "__array_namespace__"): |
| 320 | + xp = get_array_namespace(arrays[0]) |
| 321 | + return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) |
314 | 322 | return _concatenate(as_shared_dtype(arrays), axis=axis)
|
315 | 323 |
|
316 | 324 |
|
|
0 commit comments