Skip to content

Commit 3661bec

Browse files
committed
CI: download sam2 checkpoints for tests
1 parent 472535f commit 3661bec

File tree

6 files changed

+30
-13
lines changed

6 files changed

+30
-13
lines changed

.github/workflows/ci.yml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,21 @@ jobs:
2626
- name: install and lint
2727
run: |
2828
python -m pip install -e ./neat_ml/sam2/
29-
python -m pip install -v ".[dev]"
29+
python -m pip install -v -e ".[dev]"
3030
mypy neat_ml
3131
ruff check neat_ml
32+
- name: Cache SAM2 checkpoint
33+
id: cache-sam2
34+
uses: actions/cache@v4
35+
with:
36+
path: neat_ml/sam2/checkpoints/sam2_hiera_large.pt
37+
key: sam2-hiera-large-checkpoint
38+
- name: Download SAM2 checkpoint
39+
if: steps.cache-sam2.outputs.cache-hit != 'true'
40+
run: |
41+
mkdir -p neat_ml/sam2/checkpoints
42+
wget -q https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt \
43+
-O neat_ml/sam2/checkpoints/sam2_hiera_large.pt
3244
- name: install fonts
3345
if: runner.os == 'Linux'
3446
run: |

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pip install -e ".[notebooks]"
1818
cd checkpoints
1919
# modify the provided bash script to download the appropriate checkpoints
2020
git apply ../../checkpoint.diff
21-
sh download_chkpts.sh
21+
sh download_ckpts.sh
2222
```
2323

2424
NOTE:

neat_ml/bubblesam/SAM.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import os
77
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
8+
from importlib.abc import Traversable
89

910
from sam2.build_sam import build_sam2
1011
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
@@ -21,7 +22,7 @@ class SAMModel:
2122
def __init__(
2223
self,
2324
model_config: str,
24-
checkpoint_path: str,
25+
checkpoint_path: Traversable,
2526
device: str
2627
) -> None:
2728
"""
@@ -33,7 +34,7 @@ def __init__(
3334
----------
3435
model_config : str
3536
YAML cfg describing network architecture.
36-
checkpoint_path : str
37+
checkpoint_path : Path
3738
Path to *.pt checkpoint with learned weights.
3839
device : str
3940
Torch device ('cuda' | 'cpu' | 'cuda:0', …).

neat_ml/bubblesam/bubblesam.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logging
1111
from matplotlib.axes import Axes
1212
from numpy.random import Generator
13+
from importlib.resources import files
1314

1415
from skimage.measure import label, regionprops, find_contours
1516
from matplotlib.patches import Rectangle
@@ -18,7 +19,6 @@
1819

1920
memory = joblib.Memory("joblib_cache", verbose=0)
2021

21-
logging.basicConfig(level=logging.INFO)
2222
logger = logging.getLogger(__name__)
2323

2424
# these settings are used to parametrize
@@ -44,9 +44,11 @@
4444
"use_m2m": True,
4545
}
4646

47+
checkpoint_path = files("neat_ml.sam2").joinpath("checkpoints/sam2_hiera_large.pt")
48+
4749
DEFAULT_MODEL_CFG = {
4850
"model_config": "sam2_hiera_l.yaml",
49-
"checkpoint_path": "./neat_ml/sam2/checkpoints/sam2_hiera_large.pt",
51+
"checkpoint_path": checkpoint_path,
5052
"device": ("cuda" if torch.cuda.is_available()
5153
else "mps" if torch.backends.mps.is_available()
5254
else "cpu"

neat_ml/tests/test_bubblesam.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
matplotlib.use("Agg")
1212
import matplotlib.pyplot as plt
1313
from matplotlib.testing.compare import compare_images
14+
from importlib.resources import files
15+
from importlib.abc import Traversable
1416

1517
from neat_ml.bubblesam.bubblesam import (
1618
show_anns,
@@ -21,14 +23,14 @@
2123
)
2224
from neat_ml.bubblesam.SAM import SAMModel
2325

24-
CHECKPOINT = "./neat_ml/sam2/checkpoints/sam2_hiera_large.pt"
26+
CHECKPOINT = files("neat_ml.sam2").joinpath("checkpoints/sam2_hiera_large.pt")
2527

26-
def _skip_unless_available(model_chkpt: str = CHECKPOINT) -> None:
28+
def _skip_unless_available(model_chkpt: Traversable = CHECKPOINT) -> None:
2729
"""
2830
Abort the whole module if we cannot load sam2 or the checkpoint.
2931
"""
3032
pytest.importorskip("sam2", reason="sam2 package is required for SAM-2 tests")
31-
if not Path(model_chkpt).exists():
33+
if not model_chkpt.is_file():
3234
pytest.skip(
3335
f"SAM-2 checkpoint not found at {model_chkpt}. "
3436
"Install it to run integration tests.",
@@ -42,7 +44,7 @@ def _skip_unless_available(model_chkpt: str = CHECKPOINT) -> None:
4244
reason="This test is intended for systems without GPU support"
4345
)
4446
def test_setup_cuda_does_not_crash_on_cpu(
45-
model_chkpt: str = CHECKPOINT,
47+
model_chkpt: Traversable = CHECKPOINT,
4648
):
4749
"""
4850
Ensures that calling setup_cuda() in an environment with no GPU
@@ -62,7 +64,7 @@ def test_setup_cuda_does_not_crash_on_cpu(
6264
not torch.cuda.is_available(),
6365
reason="This test requires a CUDA-enabled GPU"
6466
)
65-
def test_setup_cuda_on_real_gpu(model_chkpt = CHECKPOINT):
67+
def test_setup_cuda_on_real_gpu(model_chkpt: Traversable = CHECKPOINT):
6668
"""
6769
Verifies that setup_cuda() correctly configures torch backends on
6870
a live GPU. This test only runs if a CUDA device is found.
@@ -81,7 +83,7 @@ def test_setup_cuda_on_real_gpu(model_chkpt = CHECKPOINT):
8183
assert torch.backends.cudnn.allow_tf32
8284

8385
@pytest.fixture(scope="module")
84-
def real_sam_model(model_chkpt: str = CHECKPOINT) -> SAMModel:
86+
def real_sam_model(model_chkpt: Traversable = CHECKPOINT) -> SAMModel:
8587
"""
8688
Actual SAM-2 network on CPU
8789
"""

neat_ml/tests/test_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def test_stage_detect_pipeline_runs(
179179
chkpt_dir.mkdir(parents=True)
180180
shutil.copy(
181181
files(neat_ml) / "sam2/checkpoints/sam2_hiera_large.pt",
182-
Path(tmpdir) / chkpt_dir,
182+
chkpt_dir,
183183
) # type: ignore[call-overload]
184184

185185
ds["detection"]["img_dir"] = img_dir

0 commit comments

Comments
 (0)