diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0c8b0ba..e384deb 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -34,6 +34,9 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Setup Graphviz + uses: ts-graphviz/setup-graphviz@v2 + - name: Install uv run: | curl -LsSf https://astral.sh/uv/install.sh | sh diff --git a/cubed_xarray/cubedmanager.py b/cubed_xarray/cubedmanager.py index 52ca284..0292aa3 100644 --- a/cubed_xarray/cubedmanager.py +++ b/cubed_xarray/cubedmanager.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Union import numpy as np +import xarray as xr from tlz import partition from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -227,3 +228,28 @@ def store( targets, **kwargs, ) + + +@xr.register_dataset_accessor("cubed") +class DatasetAccessor: + def __init__(self, ds): + self.ds = ds + + def visualize( + self, + filename="cubed", + format=None, + optimize_graph=True, + optimize_function=None, + show_hidden=False, + ): + import cubed + + cubed.visualize( + *(self.ds[var].data for var in self.ds.data_vars.keys()), + filename=filename, + format=format, + optimize_graph=optimize_graph, + optimize_function=optimize_function, + show_hidden=show_hidden, + ) diff --git a/cubed_xarray/tests/test_wrapping.py b/cubed_xarray/tests/test_wrapping.py index 3124e18..55f8be8 100644 --- a/cubed_xarray/tests/test_wrapping.py +++ b/cubed_xarray/tests/test_wrapping.py @@ -61,3 +61,14 @@ def test_to_zarr(tmpdir, executor): assert isinstance(restored.var1.data, cubed.Array) computed = restored.compute() assert_allclose(original, computed) + + +def test_dataset_accessor_visualize(tmp_path): + spec = cubed.Spec(allowed_mem="200MB") + + ds = create_test_data().chunk( + chunked_array_type="cubed", from_array_kwargs={"spec": spec} + ) + assert not (tmp_path / "cubed.svg").exists() + ds.cubed.visualize(filename=tmp_path / "cubed") + assert (tmp_path / "cubed.svg").exists() diff --git a/pyproject.toml b/pyproject.toml index 9557751..0439e05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ [project.optional-dependencies] test = [ + "cubed[diagnostics]", "dill", "pre-commit", "ruff",