Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4613,20 +4613,17 @@ def expand_dims(
# or iterables.
for k, v in dim.items():
if hasattr(v, "__iter__"):
# If the value for the new dimension is an iterable, then
# save the coordinates to the variables dict, and set the
# value within the dim dict to the length of the iterable
# for later use.

if create_index_for_new_dim:
index = PandasIndex(v, k)
indexes[k] = index
name_and_new_1d_var = index.create_variables()
else:
name_and_new_1d_var = {k: Variable(data=v, dims=k)}

variables.update(name_and_new_1d_var)
coord_names.add(k)
dim[k] = variables[k].size

elif isinstance(v, int):
pass # Do nothing if the dimensions value is just an int
else:
Expand Down
10 changes: 10 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,16 @@ def __init__(
coord_dtype = index.dtype
else:
coord_dtype = get_valid_numpy_dtype(index)
if coord_dtype == object and index.dtype == object:
inferred = getattr(index, "inferred_type", None)
if inferred in ("string", "unicode"):
coord_dtype = np.dtype(str)
else:
data = index.to_numpy(dtype=object, copy=False)
if data.size and all(
isinstance(x, (str, np.str_)) for x in data.ravel()
):
coord_dtype = np.asarray(data, dtype=str).dtype
self.coord_dtype = coord_dtype

def _replace(self, index, dim=None, coord_dtype=None):
Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2310,6 +2310,17 @@ def test_expand_dims_with_greater_dim_size(self) -> None:
).drop_vars("dim_0")
assert_identical(other_way_expected, other_way)

def test_expand_dims_infers_string_dtype_for_new_dim_coords(self) -> None:
da = xr.DataArray(10).expand_dims({"band": ["b1", "b2", "b3"]})
assert da.coords["band"].dtype.kind == "U"

def test_expand_dims_string_coord_does_not_poison_concat(self) -> None:
da_string = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": ["a", "b", "c"]})
da_from_expand_dims = xr.DataArray(10).expand_dims({"x": ["d", "e", "f"]})

result = xr.concat([da_string, da_from_expand_dims], dim="x")
assert result.coords["x"].dtype.kind == "U"

def test_set_index(self) -> None:
indexes = [self.mindex.get_level_values(n) for n in self.mindex.names] # type: ignore[arg-type,unused-ignore] # pandas-stubs varies
coords = {idx.name: ("x", idx) for idx in indexes}
Expand Down
Loading