|
15 | 15 | as_dataarray, |
16 | 16 | assign_multiindex_safe, |
17 | 17 | best_int, |
| 18 | + get_dims_with_index_levels, |
18 | 19 | iterate_slices, |
19 | 20 | ) |
20 | 21 |
|
@@ -529,3 +530,77 @@ def test_iterate_slices_no_slice_dims(): |
529 | 530 | for s in slices: |
530 | 531 | assert isinstance(s, xr.Dataset) |
531 | 532 | assert set(s.dims) == set(ds.dims) |
| 533 | + |
| 534 | + |
| 535 | +def test_get_dims_with_index_levels(): |
| 536 | + # Verify the labels get_dims_with_index_levels returns for each index layout. |
| 537 | + # Datasets for cases 1-4 are built first; case 5 is built inline at the end. |
| 538 | + # Case 1: only plain (non-multi-index) dimensions, "time" and "lat" |
| 539 | + ds1 = xr.Dataset( |
| 540 | + {"temp": (("time", "lat"), np.random.rand(3, 2))}, # noqa: NPY002 |
| 541 | + coords={"time": pd.date_range("2024-01-01", periods=3), "lat": [0, 1]}, |
| 542 | + ) |
| 543 | + |
| 544 | + # Case 2: "station" dimension backed by a MultiIndex with named levels |
| 545 | + stations_index = pd.MultiIndex.from_product( |
| 546 | + [["USA", "Canada"], ["NYC", "Toronto"]], names=["country", "city"] |
| 547 | + ) |
| 548 | + stations_coords = xr.Coordinates.from_pandas_multiindex(stations_index, "station") |
| 549 | + ds2 = xr.Dataset( |
| 550 | + {"temp": (("time", "station"), np.random.rand(3, 4))}, # noqa: NPY002 |
| 551 | + coords={"time": pd.date_range("2024-01-01", periods=3), **stations_coords}, |
| 552 | + ) |
| 553 | + |
| 554 | + # Case 3: same level values as case 2, but the MultiIndex levels are unnamed |
| 555 | + unnamed_stations_index = pd.MultiIndex.from_product( |
| 556 | + [["USA", "Canada"], ["NYC", "Toronto"]] |
| 557 | + ) |
| 558 | + unnamed_stations_coords = xr.Coordinates.from_pandas_multiindex( |
| 559 | + unnamed_stations_index, "station" |
| 560 | + ) |
| 561 | + ds3 = xr.Dataset( |
| 562 | + {"temp": (("time", "station"), np.random.rand(3, 4))}, # noqa: NPY002 |
| 563 | + coords={ |
| 564 | + "time": pd.date_range("2024-01-01", periods=3), |
| 565 | + **unnamed_stations_coords, |
| 566 | + }, |
| 567 | + ) |
| 568 | + |
| 569 | + # Case 4: two multi-indexed dimensions; "station" reuses stations_coords from case 2 |
| 570 | + locations_index = pd.MultiIndex.from_product( |
| 571 | + [["North", "South"], ["A", "B"]], names=["region", "site"] |
| 572 | + ) |
| 573 | + locations_coords = xr.Coordinates.from_pandas_multiindex( |
| 574 | + locations_index, "location" |
| 575 | + ) |
| 576 | + |
| 577 | + ds4 = xr.Dataset( |
| 578 | + {"temp": (("time", "station", "location"), np.random.rand(2, 4, 4))}, # noqa: NPY002 |
| 579 | + coords={ |
| 580 | + "time": pd.date_range("2024-01-01", periods=2), |
| 581 | + **stations_coords, |
| 582 | + **locations_coords, |
| 583 | + }, |
| 584 | + ) |
| 585 | + |
| 586 | + # Run tests |
| 587 | + |
| 588 | + # Test case 1: plain dimensions are expected back as bare names |
| 589 | + assert get_dims_with_index_levels(ds1) == ["time", "lat"] |
| 590 | + |
| 591 | + # Test case 2: named levels are expected in parentheses after the dimension name |
| 592 | + assert get_dims_with_index_levels(ds2) == ["time", "station (country, city)"] |
| 593 | + |
| 594 | + # Test case 3: unnamed levels are expected to appear as "<dim>_level_<n>" |
| 595 | + assert get_dims_with_index_levels(ds3) == [ |
| 596 | + "time", |
| 597 | + "station (station_level_0, station_level_1)", |
| 598 | + ] |
| 599 | + |
| 600 | + # Test case 4: each multi-indexed dimension carries its own level suffix |
| 601 | + expected = ["time", "station (country, city)", "location (region, site)"] |
| 602 | + assert get_dims_with_index_levels(ds4) == expected |
| 603 | + |
| 604 | + # Test case 5: a dataset with no dimensions is expected to give an empty list |
| 605 | + ds5 = xr.Dataset() |
| 606 | + assert get_dims_with_index_levels(ds5) == [] |
0 commit comments