Skip to content

Commit 0f8b0cc

Browse files
authored
Make matplotlib dependency optional (#1131)
* Make matplotlib dependency optional * Add dep for docs * Improve testing. * Auto-format.
1 parent eb1e647 commit 0f8b0cc

File tree

13 files changed

+105
-64
lines changed

13 files changed

+105
-64
lines changed

.github/workflows/smoke-test.yml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,23 @@ jobs:
3030
- name: Install dependencies
3131
run: |
3232
sudo apt-get update
33-
uv pip install --system -e .[dev,full]
33+
uv pip install --system -e .[dev]
3434
if [ -f requirements.txt ]; then uv pip install --system -r requirements.txt; fi
3535
- name: List dependencies
3636
run: |
3737
pip list
3838
- name: Run unit tests with pytest
3939
run: |
4040
python -m pytest
41+
- name: Install FULL dependencies
42+
run: |
43+
uv pip install --system -e .[full]
44+
- name: List FULL dependencies
45+
run: |
46+
pip list
47+
- name: Run FULL unit tests with pytest
48+
run: |
49+
python -m pytest
4150
- name: Send status to Slack app
4251
if: ${{ failure() && github.event_name != 'workflow_dispatch' }}
4352
id: slack

.setup_dev.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ python -m pip install -e . > /dev/null
4141

4242
echo "Installing developer dependencies in local environment"
4343
python -m pip install -e .'[dev]' > /dev/null
44+
python -m pip install -e .'[full]' > /dev/null
45+
if [ -f requirements.txt ]; then python -m pip install -r requirements.txt > /dev/null; fi
4446
if [ -f docs/requirements.txt ]; then python -m pip install -r docs/requirements.txt > /dev/null; fi
4547

4648
echo "Installing pre-commit"

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ ipython
44
ipywidgets
55
jupytext
66
lsst-sphgeom
7+
matplotlib
78
nbconvert
89
nbsphinx
910
sphinx

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ full = [
5252
"ipykernel", # Support for Jupyter notebooks
5353
"ipywidgets", # useful for tqdm in notebooks.
5454
"lsst-sphgeom ; sys_platform == 'darwin' or sys_platform == 'linux'", # To handle spherical sky polygons, not available on Windows
55+
"matplotlib",
5556
]
5657

5758
[build-system]
@@ -65,9 +66,6 @@ build-backend = "setuptools.build_meta"
6566
write_to = "src/lsdb/_version.py"
6667

6768
[tool.pytest.ini_options]
68-
markers = [
69-
"sphgeom: mark tests as having a runtime dependency on lsst-sphgeom",
70-
]
7169
testpaths = [
7270
"tests",
7371
"src",

src/lsdb/catalog/dataset/healpix_dataset.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
from collections.abc import Sequence
77
from pathlib import Path
8-
from typing import Callable, Iterable, Type
8+
from typing import TYPE_CHECKING, Callable, Iterable, Type
99

1010
import astropy
1111
import dask
@@ -15,15 +15,11 @@
1515
import pandas as pd
1616
from astropy.coordinates import SkyCoord
1717
from astropy.units import Quantity
18-
from astropy.visualization.wcsaxes import WCSAxes
19-
from astropy.visualization.wcsaxes.frame import BaseFrame
2018
from dask.dataframe.core import _repr_data_series
2119
from deprecated import deprecated # type: ignore
2220
from hats.catalog.healpix_dataset.healpix_dataset import HealpixDataset as HCHealpixDataset
23-
from hats.inspection.visualize_catalog import get_fov_moc_from_wcs, initialize_wcs_axes
2421
from hats.pixel_math import HealpixPixel
2522
from hats.pixel_math.healpix_pixel_function import get_pixel_argsort
26-
from matplotlib.figure import Figure
2723
from mocpy import MOC
2824
from pandas._typing import Renamer
2925
from typing_extensions import Self
@@ -48,6 +44,11 @@
4844
from lsdb.loaders.hats.hats_loading_config import HatsLoadingConfig
4945
from lsdb.types import DaskDFPixelMap
5046

47+
if TYPE_CHECKING:
48+
from astropy.visualization.wcsaxes import WCSAxes
49+
from astropy.visualization.wcsaxes.frame import BaseFrame
50+
from matplotlib.figure import Figure
51+
5152

5253
# pylint: disable=protected-access,too-many-public-methods,too-many-lines,import-outside-toplevel,cyclic-import
5354
class HealpixDataset(Dataset):
@@ -1376,7 +1377,16 @@ def plot_points(
13761377
tuple[Figure, WCSAxes]
13771378
The figure and axes used for the plot
13781379
"""
1379-
fig, ax, wcs = initialize_wcs_axes(
1380+
try:
1381+
# pylint: disable=import-outside-toplevel
1382+
from hats.inspection._plotting import _get_fov_moc_from_wcs, _initialize_wcs_axes
1383+
from matplotlib import pyplot as plt # pylint: disable=unused-import
1384+
except ImportError as exc:
1385+
raise ImportError(
1386+
"matplotlib is required to use this method. Install with pip or conda."
1387+
) from exc
1388+
1389+
fig, ax, wcs = _initialize_wcs_axes(
13801390
projection=projection,
13811391
fov=fov,
13821392
center=center,
@@ -1387,7 +1397,7 @@ def plot_points(
13871397
figsize=(9, 5),
13881398
)
13891399

1390-
fov_moc = get_fov_moc_from_wcs(wcs)
1400+
fov_moc = _get_fov_moc_from_wcs(wcs)
13911401

13921402
computed_catalog = (
13931403
self.search(MOCSearch(fov_moc)).compute() if fov_moc is not None else self.compute()

src/lsdb/core/plotting/plot_points.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
from __future__ import annotations
22

3-
from typing import Type
3+
from typing import TYPE_CHECKING, Type
44

55
import astropy
6-
import matplotlib.pyplot as plt
76
import pandas as pd
87
from astropy.coordinates import SkyCoord
98
from astropy.units import Quantity
10-
from astropy.visualization.wcsaxes import WCSAxes
11-
from astropy.visualization.wcsaxes.frame import BaseFrame
12-
from hats.inspection.visualize_catalog import initialize_wcs_axes
13-
from matplotlib.figure import Figure
14-
from mocpy.moc.plot.utils import _set_wcs
9+
10+
if TYPE_CHECKING:
11+
from astropy.visualization.wcsaxes import WCSAxes
12+
from astropy.visualization.wcsaxes.frame import BaseFrame
13+
from matplotlib.figure import Figure
1514

1615

1716
def plot_points(
@@ -81,7 +80,16 @@ def plot_points(
8180
tuple[Figure, WCSAxes]
8281
The figure and axes used for the plot
8382
"""
84-
fig, ax, wcs = initialize_wcs_axes(
83+
84+
try:
85+
# pylint: disable=import-outside-toplevel
86+
from hats.inspection.visualize_catalog import _initialize_wcs_axes
87+
from matplotlib import pyplot as plt
88+
from mocpy.moc.plot.utils import _set_wcs
89+
except ImportError as exc:
90+
raise ImportError("matplotlib is required to use this method. Install with pip or conda.") from exc
91+
92+
fig, ax, wcs = _initialize_wcs_axes(
8593
projection=projection,
8694
fov=fov,
8795
center=center,

src/lsdb/core/search/abstract_search.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
import nested_pandas as npd
88
from astropy.coordinates import SkyCoord
99
from astropy.units import Quantity
10-
from astropy.visualization.wcsaxes import WCSAxes
11-
from astropy.visualization.wcsaxes.frame import BaseFrame
1210
from hats.catalog import TableProperties
13-
from hats.inspection.visualize_catalog import initialize_wcs_axes
14-
from matplotlib import pyplot as plt
15-
from matplotlib.figure import Figure
1611

1712
if TYPE_CHECKING:
13+
from astropy.visualization.wcsaxes import WCSAxes
14+
from astropy.visualization.wcsaxes.frame import BaseFrame
15+
from matplotlib.figure import Figure
16+
1817
from lsdb.types import HCCatalogTypeVar
1918

2019

@@ -135,7 +134,16 @@ def plot(
135134
tuple[Figure, WCSAxes]
136135
The figure and axes used for the plot
137136
"""
138-
fig, ax, wcs = initialize_wcs_axes(
137+
try:
138+
# pylint: disable=import-outside-toplevel
139+
from hats.inspection._plotting import _initialize_wcs_axes
140+
from matplotlib import pyplot as plt
141+
except ImportError as exc:
142+
raise ImportError(
143+
"matplotlib is required to use this method. Install with pip or conda."
144+
) from exc
145+
146+
fig, ax, wcs = _initialize_wcs_axes(
139147
projection=projection,
140148
fov=fov,
141149
center=center,

src/lsdb/core/search/region_search.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING
4+
35
import astropy.units as u
46
import nested_pandas as npd
57
import pandas as pd
6-
from astropy.visualization.wcsaxes import SphericalCircle, WCSAxes
78
from hats.catalog import TableProperties
89
from hats.pixel_math import HealpixPixel, get_healpix_pixel, spatial_index
910
from hats.pixel_math.region_to_moc import wrap_ra_angles
@@ -24,6 +25,9 @@
2425
from lsdb.core.search.abstract_search import AbstractSearch
2526
from lsdb.types import HCCatalogTypeVar
2627

28+
if TYPE_CHECKING:
29+
from astropy.visualization.wcsaxes import WCSAxes
30+
2731

2832
class BoxSearch(AbstractSearch):
2933
"""Perform a box search to filter the catalog. This type of search is used for a
@@ -73,6 +77,14 @@ def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> np
7377
return cone_filter(frame, self.ra, self.dec, self.radius_arcsec, metadata)
7478

7579
def _perform_plot(self, ax: WCSAxes, **kwargs):
80+
try:
81+
# pylint: disable=import-outside-toplevel
82+
from astropy.visualization.wcsaxes import SphericalCircle
83+
except ImportError as exc:
84+
raise ImportError(
85+
"matplotlib is required to use this method. Install with pip or conda."
86+
) from exc
87+
7688
kwargs_to_use = {"ec": "tab:red", "fc": "none"}
7789
kwargs_to_use.update(kwargs)
7890

@@ -194,6 +206,13 @@ class PolygonSearch(AbstractSearch):
194206
"""
195207

196208
def __init__(self, vertices: list[tuple[float, float]], fine: bool = True):
209+
try:
210+
# pylint: disable=unused-import,import-outside-toplevel
211+
from lsst.sphgeom import ConvexPolygon
212+
except ImportError as exc:
213+
raise ImportError(
214+
"lsst-sphgeom is required to use this method. Install with pip or conda."
215+
) from exc
197216
super().__init__(fine)
198217
validate_polygon(vertices)
199218
self.vertices = vertices

tests/conftest.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -396,27 +396,6 @@ def cone_search_margin_expected(cone_search_expected_dir):
396396
return pd.read_csv(cone_search_expected_dir / "margin.csv", index_col=SPATIAL_INDEX_COLUMN)
397397

398398

399-
# pylint: disable=import-outside-toplevel
400-
def pytest_collection_modifyitems(items):
401-
"""Modify tests that use the `lsst-sphgeom` package to only run when that
402-
package has been installed in the development environment.
403-
404-
If we detect that we can import `lsst-sphgeom`, this method exits early
405-
and does not modify any test items.
406-
"""
407-
try:
408-
# pylint: disable=unused-import
409-
from lsst.sphgeom import ConvexPolygon
410-
411-
return
412-
except ImportError:
413-
pass
414-
415-
for item in items:
416-
if any(item.iter_markers(name="sphgeom")):
417-
item.add_marker(pytest.mark.skip(reason="lsst-sphgeom is not installed"))
418-
419-
420399
class Helpers:
421400
@staticmethod
422401
def assert_divisions_are_correct(catalog):

tests/lsdb/catalog/test_catalog.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@
55
import dask.dataframe as dd
66
import hats as hc
77
import hats.pixel_math.healpix_shim as hp
8-
import matplotlib as mpl
9-
import matplotlib.pyplot as plt
108
import nested_pandas as npd
119
import numpy as np
1210
import numpy.testing as npt
1311
import pandas as pd
1412
import pytest
1513
from astropy.coordinates import SkyCoord
1614
from astropy.visualization.wcsaxes import WCSAxes
17-
from hats.inspection.visualize_catalog import get_fov_moc_from_wcs
15+
from hats.inspection._plotting import _get_fov_moc_from_wcs
1816
from hats.pixel_math import HealpixPixel
1917
from mocpy import WCS
2018
from nested_pandas.datasets import generate_data
@@ -24,12 +22,14 @@
2422
from lsdb import Catalog, MarginCatalog
2523
from lsdb.core.search.region_search import MOCSearch
2624

27-
mpl.use("Agg")
28-
2925

3026
@pytest.fixture(autouse=True)
3127
def reset_matplotlib():
3228
yield
29+
mpl = pytest.importorskip("matplotlib")
30+
plt = pytest.importorskip("matplotlib.pyplot")
31+
32+
mpl.use("Agg")
3333
plt.close("all")
3434

3535

@@ -667,9 +667,9 @@ def test_filtered_catalog_has_undetermined_len(small_sky_order1_catalog, small_s
667667
len(small_sky_order1_catalog.pixel_search([(0, 11)]))
668668

669669

670-
@pytest.mark.sphgeom
671670
def test_filtered_catalog_has_undetermined_len_polygon(small_sky_order1_catalog):
672671
"""Tests that filtered catalogs have an undetermined number of rows"""
672+
pytest.importorskip("lsst.sphgeom")
673673
with pytest.raises(ValueError, match="undetermined"):
674674
vertices = [(300, -50), (300, -55), (272, -55), (272, -50)]
675675
len(small_sky_order1_catalog.polygon_search(vertices))
@@ -736,12 +736,13 @@ def test_plot_points(small_sky_order1_catalog, mocker):
736736

737737

738738
def test_plot_points_fov(small_sky_order1_catalog, mocker):
739+
plt = pytest.importorskip("matplotlib.pyplot")
739740
mocker.patch("astropy.visualization.wcsaxes.WCSAxes.scatter")
740741
fig = plt.figure(figsize=(10, 6))
741742
center = SkyCoord(350, -80, unit="deg")
742743
fov = 10 * u.deg
743744
wcs = WCS(fig=fig, fov=fov, center=center, projection="MOL").w
744-
wcs_moc = get_fov_moc_from_wcs(wcs)
745+
wcs_moc = _get_fov_moc_from_wcs(wcs)
745746
_, ax = small_sky_order1_catalog.plot_points(fov=fov, center=center)
746747
comp_cat = small_sky_order1_catalog.search(MOCSearch(wcs_moc)).compute()
747748
WCSAxes.scatter.assert_called_once()
@@ -751,12 +752,13 @@ def test_plot_points_fov(small_sky_order1_catalog, mocker):
751752

752753

753754
def test_plot_points_wcs(small_sky_order1_catalog, mocker):
755+
plt = pytest.importorskip("matplotlib.pyplot")
754756
mocker.patch("astropy.visualization.wcsaxes.WCSAxes.scatter")
755757
fig = plt.figure(figsize=(10, 6))
756758
center = SkyCoord(350, -80, unit="deg")
757759
fov = 10 * u.deg
758760
wcs = WCS(fig=fig, fov=fov, center=center).w
759-
wcs_moc = get_fov_moc_from_wcs(wcs)
761+
wcs_moc = _get_fov_moc_from_wcs(wcs)
760762
_, ax = small_sky_order1_catalog.plot_points(wcs=wcs)
761763
comp_cat = small_sky_order1_catalog.search(MOCSearch(wcs_moc)).compute()
762764
WCSAxes.scatter.assert_called_once()
@@ -766,6 +768,7 @@ def test_plot_points_wcs(small_sky_order1_catalog, mocker):
766768

767769

768770
def test_plot_points_colorcol(small_sky_order1_catalog, mocker):
771+
plt = pytest.importorskip("matplotlib.pyplot")
769772
mocker.patch("astropy.visualization.wcsaxes.WCSAxes.scatter")
770773
mocker.patch("matplotlib.pyplot.colorbar")
771774
_, ax = small_sky_order1_catalog.plot_points(color_col="id")

0 commit comments

Comments
 (0)