Add .prune() method to remove empty nodes from DataTree after filtering operations #10590

@aladinor

Is your feature request related to a problem?

I propose adding a .prune() method to DataTree that removes empty nodes (nodes with no data variables, or whose data variables are all empty) from the tree structure. This would be particularly useful after filtering operations, which can leave some nodes empty.

When performing selection operations on DataTrees, such as sel(), some nodes may end up empty if their coordinate ranges don't overlap with the selection criteria. Currently, these empty nodes remain in the tree structure, which can be confusing and problematic for downstream operations.

Reproducible Example

import xarray as xr
import numpy as np
import pandas as pd

# Create DataTree with different time ranges in each node
ds1 = xr.Dataset(
    {"foo": ("time", np.random.rand(5))},
    coords={"time": pd.date_range("2023-01-01", periods=5, freq="D")}
)
ds2 = xr.Dataset(
    {"var": ("time", np.random.rand(5))},
    coords={"time": pd.date_range("2023-01-04", periods=5, freq="D")}
)

dtree = xr.DataTree.from_dict({"a": ds1, "b": ds2})
print("Original tree:")
print(dtree)

# Filter - node 'b' becomes empty since its time range doesn't overlap
filtered = dtree.sel(time=slice("2023-01-01", "2023-01-03"))
print("\nAfter filtering:")
print(filtered)

Current Output:
Original tree:
<xarray.DataTree>
Group: /
├── Group: /a
│       Dimensions:  (time: 5)
│       Data variables: foo
└── Group: /b
        Dimensions:  (time: 5)
        Data variables: var

After filtering:
<xarray.DataTree>
Group: /
├── Group: /a
│       Dimensions:  (time: 3)
│       Data variables: foo
└── Group: /b          # <- Empty node remains
        Dimensions:  (time: 0)
        Data variables: var (empty)

Describe the solution you'd like

A .prune() method that returns a new DataTree with the empty nodes removed.

# Proposed usage
pruned = filtered.prune()  # Remove empty nodes
print("\nAfter pruning:")
print(pruned)  # Should only show node 'a'

Desired Output after .prune():
<xarray.DataTree>
Group: /
└── Group: /a
        Dimensions:  (time: 3)
        Data variables: foo

Describe alternatives you've considered

Proposed API

def prune(self, recursive=True):
    """
    Remove empty nodes from the DataTree.

    Parameters
    ----------
    recursive : bool, default True
        If True, also remove parent nodes that become empty
        after their children are pruned.

    Returns
    -------
    DataTree
        A new DataTree with empty nodes removed.
    """

Use Cases

  1. Post-filtering cleanup: Remove empty nodes after sel(), isel(), where() operations
  2. Data pipeline optimization: Clean up intermediate results in processing workflows
  3. Tree structure simplification: Remove nodes that have become empty due to data processing
  4. Memory optimization: Remove unnecessary empty nodes to reduce the memory footprint

Definition of "Empty Node"

A node would be considered empty if:

  • It has no data variables, OR
  • All its data variables have zero-sized dimensions, OR
  • All its data variables contain only NaN/null values (optional behavior)
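
The three conditions above could be checked per node with a small predicate on the node's dataset. This is only a sketch: the function name `is_empty_dataset` and the flag `treat_all_nan_as_empty` are hypothetical and not part of xarray.

```python
import numpy as np
import xarray as xr


def is_empty_dataset(ds: xr.Dataset, treat_all_nan_as_empty: bool = False) -> bool:
    """Hypothetical helper: apply the proposed "empty node" rules to one Dataset."""
    data_vars = list(ds.data_vars.values())

    # Rule 1: no data variables at all.
    if not data_vars:
        return True

    # Rule 2: every data variable has at least one zero-sized dimension.
    if all(var.size == 0 for var in data_vars):
        return True

    # Rule 3 (optional): every data variable holds only NaN/null values.
    if treat_all_nan_as_empty and all(bool(var.isnull().all()) for var in data_vars):
        return True

    return False


# The 'b' node from the example above, after filtering, has time: 0.
filtered_b = xr.Dataset({"var": ("time", np.array([]))})
print(is_empty_dataset(filtered_b))
```

Keeping the NaN check behind an opt-in flag mirrors the "optional behavior" wording: scanning every variable for nulls may force lazily loaded data to be computed, which the first two checks avoid.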

Implementation Considerations

  • Should preserve coordinate information in remaining nodes
  • Should handle parent nodes that become empty after children are pruned
  • Should work with both named and unnamed nodes
  • Should maintain tree structure integrity

Additional context

CC @TomNicholas
