1111matplotlib .use ("Agg" )
1212import matplotlib .pyplot as plt
1313from matplotlib .testing .compare import compare_images
14+ from importlib .resources import files
15+ from importlib .abc import Traversable
1416
1517from neat_ml .bubblesam .bubblesam import (
1618 show_anns ,
2123)
2224from neat_ml .bubblesam .SAM import SAMModel
2325
24- CHECKPOINT = "./ neat_ml/ sam2/ checkpoints/sam2_hiera_large.pt"
26+ CHECKPOINT = files ( " neat_ml. sam2" ). joinpath ( " checkpoints/sam2_hiera_large.pt")
2527
26- def _skip_unless_available (model_chkpt : str = CHECKPOINT ) -> None :
28+ def _skip_unless_available (model_chkpt : Traversable = CHECKPOINT ) -> None :
2729 """
2830 Abort the whole module if we cannot load sam2 or the checkpoint.
2931 """
3032 pytest .importorskip ("sam2" , reason = "sam2 package is required for SAM-2 tests" )
31- if not Path ( model_chkpt ). exists ():
33+ if not model_chkpt . is_file ():
3234 pytest .skip (
3335 f"SAM-2 checkpoint not found at { model_chkpt } . "
3436 "Install it to run integration tests." ,
@@ -42,7 +44,7 @@ def _skip_unless_available(model_chkpt: str = CHECKPOINT) -> None:
4244 reason = "This test is intended for systems without GPU support"
4345)
4446def test_setup_cuda_does_not_crash_on_cpu (
45- model_chkpt : str = CHECKPOINT ,
47+ model_chkpt : Traversable = CHECKPOINT ,
4648):
4749 """
4850 Ensures that calling setup_cuda() in an environment with no GPU
@@ -62,7 +64,7 @@ def test_setup_cuda_does_not_crash_on_cpu(
6264 not torch .cuda .is_available (),
6365 reason = "This test requires a CUDA-enabled GPU"
6466)
65- def test_setup_cuda_on_real_gpu (model_chkpt = CHECKPOINT ):
67+ def test_setup_cuda_on_real_gpu (model_chkpt : Traversable = CHECKPOINT ):
6668 """
6769 Verifies that setup_cuda() correctly configures torch backends on
6870 a live GPU. This test only runs if a CUDA device is found.
@@ -81,7 +83,7 @@ def test_setup_cuda_on_real_gpu(model_chkpt = CHECKPOINT):
8183 assert torch .backends .cudnn .allow_tf32
8284
8385@pytest .fixture (scope = "module" )
84- def real_sam_model (model_chkpt : str = CHECKPOINT ) -> SAMModel :
86+ def real_sam_model (model_chkpt : Traversable = CHECKPOINT ) -> SAMModel :
8587 """
8688 Actual SAM-2 network on CPU
8789 """
0 commit comments