Skip to content

Commit d2c9fd8

Browse files
authored
fix: don't destroy dtypes in to_dataframe (#10705)
1 parent 7893348 commit d2c9fd8

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

xarray/core/coordinates.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,13 @@ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index:
180180
np.tile(np.repeat(code, repeat_counts[i]), tile_counts[i])
181181
for code in codes
182182
]
183-
level_list += [list(level) for level in levels]
183+
level_list += levels
184184
names += index.names
185185

186186
return pd.MultiIndex(
187-
levels=level_list, codes=[list(c) for c in code_list], names=names
187+
levels=level_list, # type: ignore[arg-type,unused-ignore]
188+
codes=[list(c) for c in code_list],
189+
names=names,
188190
)
189191

190192

xarray/tests/test_dataarray.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3528,6 +3528,34 @@ def test_to_dataframe_0length(self) -> None:
35283528
assert len(actual) == 0
35293529
assert_array_equal(actual.index.names, list("ABC"))
35303530

3531+
@pytest.mark.parametrize(
3532+
"x_dtype,y_dtype,v_dtype",
3533+
[
3534+
(np.uint32, np.float32, np.uint32),
3535+
(np.int16, np.float64, np.int64),
3536+
(np.uint8, np.float32, np.uint16),
3537+
(np.int32, np.float32, np.int8),
3538+
],
3539+
)
3540+
def test_to_dataframe_coord_dtypes_2d(self, x_dtype, y_dtype, v_dtype) -> None:
3541+
x = np.array([1], dtype=x_dtype)
3542+
y = np.array([1.0], dtype=y_dtype)
3543+
v = np.array([[42]], dtype=v_dtype)
3544+
3545+
da = DataArray(v, dims=["x", "y"], coords={"x": x, "y": y})
3546+
df = da.to_dataframe(name="v").reset_index()
3547+
3548+
# Check that coordinate dtypes are preserved
3549+
assert df["x"].dtype == np.dtype(x_dtype), (
3550+
f"x coord: expected {x_dtype}, got {df['x'].dtype}"
3551+
)
3552+
assert df["y"].dtype == np.dtype(y_dtype), (
3553+
f"y coord: expected {y_dtype}, got {df['y'].dtype}"
3554+
)
3555+
assert df["v"].dtype == np.dtype(v_dtype), (
3556+
f"v data: expected {v_dtype}, got {df['v'].dtype}"
3557+
)
3558+
35313559
@requires_dask_expr
35323560
@requires_dask
35333561
@pytest.mark.xfail(not has_dask_ge_2025_1_0, reason="dask-expr is broken")

0 commit comments

Comments
 (0)