Skip to content

Commit fb67358

Browse files
authored
coords: retain str dtype (#4759)
* coords: retain str dtype * fix doctests * update what's new * fix multiindex repr * rename function * ensure minimum str dtype * fix EOL spaces
1 parent f52a95c commit fb67358

File tree

13 files changed

+193
-12
lines changed

13 files changed

+193
-12
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ Bug fixes
6666
By `Anderson Banihirwe <https://github.com/andersy005>`_
6767
- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`).
6868
By `Alessandro Amici <https://github.com/alexamici>`_
69+
- Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations,
70+
e.g. ``reindex``, ``align``, ``concat``, ``assign``, previously they were cast to an object dtype
71+
(:issue:`2658` and :issue:`4543`) by `Mathias Hauser <https://github.com/mathause>`_.
6972
- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling <https://github.com/illviljan>`_.
7073
- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo <https://github.com/mesejo>`_.
7174
- Resolve intervals before appending other metadata to labels when plotting (:issue:`4322`, :pull:`4794`).

xarray/core/alignment.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from . import dtypes, utils
2121
from .indexing import get_indexer_nd
22-
from .utils import is_dict_like, is_full_slice
22+
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str
2323
from .variable import IndexVariable, Variable
2424

2525
if TYPE_CHECKING:
@@ -278,10 +278,12 @@ def align(
278278
return (obj.copy(deep=copy),)
279279

280280
all_indexes = defaultdict(list)
281+
all_coords = defaultdict(list)
281282
unlabeled_dim_sizes = defaultdict(set)
282283
for obj in objects:
283284
for dim in obj.dims:
284285
if dim not in exclude:
286+
all_coords[dim].append(obj.coords[dim])
285287
try:
286288
index = obj.indexes[dim]
287289
except KeyError:
@@ -306,7 +308,7 @@ def align(
306308
any(not index.equals(other) for other in matching_indexes)
307309
or dim in unlabeled_dim_sizes
308310
):
309-
joined_indexes[dim] = index
311+
joined_indexes[dim] = indexes[dim]
310312
else:
311313
if (
312314
any(
@@ -318,9 +320,11 @@ def align(
318320
if join == "exact":
319321
raise ValueError(f"indexes along dimension {dim!r} are not equal")
320322
index = joiner(matching_indexes)
323+
# make sure str coords are not cast to object
324+
index = maybe_coerce_to_str(index, all_coords[dim])
321325
joined_indexes[dim] = index
322326
else:
323-
index = matching_indexes[0]
327+
index = all_coords[dim][0]
324328

325329
if dim in unlabeled_dim_sizes:
326330
unlabeled_sizes = unlabeled_dim_sizes[dim]
@@ -583,7 +587,7 @@ def reindex_variables(
583587
args: tuple = (var.attrs, var.encoding)
584588
else:
585589
args = ()
586-
reindexed[dim] = IndexVariable((dim,), target, *args)
590+
reindexed[dim] = IndexVariable((dim,), indexers[dim], *args)
587591

588592
for dim in sizes:
589593
if dim not in indexes and dim in indexers:

xarray/core/concat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def concat(
187187
array([[0, 1, 2],
188188
[3, 4, 5]])
189189
Coordinates:
190-
* x (x) object 'a' 'b'
190+
* x (x) <U1 'a' 'b'
191191
* y (y) int64 10 20 30
192192
193193
>>> xr.concat([da.isel(x=0), da.isel(x=1)], "new_dim")
@@ -503,7 +503,7 @@ def ensure_common_dims(vars):
503503
for k in datasets[0].variables:
504504
if k in concat_over:
505505
try:
506-
vars = ensure_common_dims([ds.variables[k] for ds in datasets])
506+
vars = ensure_common_dims([ds[k].variable for ds in datasets])
507507
except KeyError:
508508
raise ValueError("%r is not present in all datasets." % k)
509509
combined = concat_vars(vars, dim, positions)

xarray/core/dataarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1325,8 +1325,8 @@ def broadcast_like(
13251325
[ 2.2408932 , 1.86755799, -0.97727788],
13261326
[ nan, nan, nan]])
13271327
Coordinates:
1328-
* x (x) object 'a' 'b' 'c'
1329-
* y (y) object 'a' 'b' 'c'
1328+
* x (x) <U1 'a' 'b' 'c'
1329+
* y (y) <U1 'a' 'b' 'c'
13301330
"""
13311331
if exclude is None:
13321332
exclude = set()

xarray/core/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2565,7 +2565,7 @@ def reindex(
25652565
<xarray.Dataset>
25662566
Dimensions: (station: 4)
25672567
Coordinates:
2568-
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
2568+
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
25692569
Data variables:
25702570
temperature (station) float64 10.98 nan 12.06 nan
25712571
pressure (station) float64 211.8 nan 218.8 nan
@@ -2576,7 +2576,7 @@ def reindex(
25762576
<xarray.Dataset>
25772577
Dimensions: (station: 4)
25782578
Coordinates:
2579-
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
2579+
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
25802580
Data variables:
25812581
temperature (station) float64 10.98 0.0 12.06 0.0
25822582
pressure (station) float64 211.8 0.0 218.8 0.0
@@ -2589,7 +2589,7 @@ def reindex(
25892589
<xarray.Dataset>
25902590
Dimensions: (station: 4)
25912591
Coordinates:
2592-
* station (station) object 'boston' 'austin' 'seattle' 'lincoln'
2592+
* station (station) <U7 'boston' 'austin' 'seattle' 'lincoln'
25932593
Data variables:
25942594
temperature (station) float64 10.98 0.0 12.06 0.0
25952595
pressure (station) float64 211.8 100.0 218.8 100.0

xarray/core/merge.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -930,9 +930,11 @@ def dataset_update_method(
930930
if coord_names:
931931
other[key] = value.drop_vars(coord_names)
932932

933+
# use ds.coords and not ds.indexes, else str coords are cast to object
934+
indexes = {key: dataset.coords[key] for key in dataset.indexes.keys()}
933935
return merge_core(
934936
[dataset, other],
935937
priority_arg=1,
936-
indexes=dataset.indexes,
938+
indexes=indexes,
937939
combine_attrs="override",
938940
)

xarray/core/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
import numpy as np
3232
import pandas as pd
3333

34+
from . import dtypes
35+
3436
K = TypeVar("K")
3537
V = TypeVar("V")
3638
T = TypeVar("T")
@@ -76,6 +78,23 @@ def maybe_cast_to_coords_dtype(label, coords_dtype):
7678
return label
7779

7880

81+
def maybe_coerce_to_str(index, original_coords):
82+
"""maybe coerce a pandas Index back to a nunpy array of type str
83+
84+
pd.Index uses object-dtype to store str - try to avoid this for coords
85+
"""
86+
87+
try:
88+
result_type = dtypes.result_type(*original_coords)
89+
except TypeError:
90+
pass
91+
else:
92+
if result_type.kind in "SU":
93+
index = np.asarray(index, dtype=result_type.type)
94+
95+
return index
96+
97+
7998
def safe_cast_to_index(array: Any) -> pd.Index:
8099
"""Given an array, safely cast it to a pandas.Index.
81100

xarray/core/variable.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
ensure_us_time_resolution,
4949
infix_dims,
5050
is_duck_array,
51+
maybe_coerce_to_str,
5152
)
5253

5354
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
@@ -2523,6 +2524,9 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False):
25232524
indices = nputils.inverse_permutation(np.concatenate(positions))
25242525
data = data.take(indices)
25252526

2527+
# keep as str if possible as pandas.Index uses object (converts to numpy array)
2528+
data = maybe_coerce_to_str(data, variables)
2529+
25262530
attrs = dict(first_var.attrs)
25272531
if not shortcut:
25282532
for var in variables:

xarray/tests/test_concat.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,30 @@ def test_concat_fill_value(self, fill_value):
376376
actual = concat(datasets, dim="t", fill_value=fill_value)
377377
assert_identical(actual, expected)
378378

379+
@pytest.mark.parametrize("dtype", [str, bytes])
380+
@pytest.mark.parametrize("dim", ["x1", "x2"])
381+
def test_concat_str_dtype(self, dtype, dim):
382+
383+
data = np.arange(4).reshape([2, 2])
384+
385+
da1 = Dataset(
386+
{
387+
"data": (["x1", "x2"], data),
388+
"x1": [0, 1],
389+
"x2": np.array(["a", "b"], dtype=dtype),
390+
}
391+
)
392+
da2 = Dataset(
393+
{
394+
"data": (["x1", "x2"], data),
395+
"x1": np.array([1, 2]),
396+
"x2": np.array(["c", "d"], dtype=dtype),
397+
}
398+
)
399+
actual = concat([da1, da2], dim=dim)
400+
401+
assert np.issubdtype(actual.x2.dtype, dtype)
402+
379403

380404
class TestConcatDataArray:
381405
def test_concat(self):
@@ -525,6 +549,26 @@ def test_concat_combine_attrs_kwarg(self):
525549
actual = concat([da1, da2], dim="x", combine_attrs=combine_attrs)
526550
assert_identical(actual, expected[combine_attrs])
527551

552+
@pytest.mark.parametrize("dtype", [str, bytes])
553+
@pytest.mark.parametrize("dim", ["x1", "x2"])
554+
def test_concat_str_dtype(self, dtype, dim):
555+
556+
data = np.arange(4).reshape([2, 2])
557+
558+
da1 = DataArray(
559+
data=data,
560+
dims=["x1", "x2"],
561+
coords={"x1": [0, 1], "x2": np.array(["a", "b"], dtype=dtype)},
562+
)
563+
da2 = DataArray(
564+
data=data,
565+
dims=["x1", "x2"],
566+
coords={"x1": np.array([1, 2]), "x2": np.array(["c", "d"], dtype=dtype)},
567+
)
568+
actual = concat([da1, da2], dim=dim)
569+
570+
assert np.issubdtype(actual.x2.dtype, dtype)
571+
528572

529573
@pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {}))
530574
@pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {}))

xarray/tests/test_dataarray.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,19 @@ def test_reindex_fill_value(self, fill_value):
15681568
)
15691569
assert_identical(expected, actual)
15701570

1571+
@pytest.mark.parametrize("dtype", [str, bytes])
1572+
def test_reindex_str_dtype(self, dtype):
1573+
1574+
data = DataArray(
1575+
[1, 2], dims="x", coords={"x": np.array(["a", "b"], dtype=dtype)}
1576+
)
1577+
1578+
actual = data.reindex(x=data.x)
1579+
expected = data
1580+
1581+
assert_identical(expected, actual)
1582+
assert actual.dtype == expected.dtype
1583+
15711584
def test_rename(self):
15721585
renamed = self.dv.rename("bar")
15731586
assert_identical(renamed.to_dataset(), self.ds.rename({"foo": "bar"}))
@@ -3435,6 +3448,26 @@ def test_align_without_indexes_errors(self):
34353448
DataArray([1, 2], coords=[("x", [0, 1])]),
34363449
)
34373450

3451+
def test_align_str_dtype(self):
3452+
3453+
a = DataArray([0, 1], dims=["x"], coords={"x": ["a", "b"]})
3454+
b = DataArray([1, 2], dims=["x"], coords={"x": ["b", "c"]})
3455+
3456+
expected_a = DataArray(
3457+
[0, 1, np.NaN], dims=["x"], coords={"x": ["a", "b", "c"]}
3458+
)
3459+
expected_b = DataArray(
3460+
[np.NaN, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]}
3461+
)
3462+
3463+
actual_a, actual_b = xr.align(a, b, join="outer")
3464+
3465+
assert_identical(expected_a, actual_a)
3466+
assert expected_a.x.dtype == actual_a.x.dtype
3467+
3468+
assert_identical(expected_b, actual_b)
3469+
assert expected_b.x.dtype == actual_b.x.dtype
3470+
34383471
def test_broadcast_arrays(self):
34393472
x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x")
34403473
y = DataArray([1, 2], coords=[("b", [3, 4])], name="y")

0 commit comments

Comments
 (0)