Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions atomai/models/sam.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import cv2
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import pandas as pd
import matplotlib.pyplot as plt
import torch
Expand All @@ -26,7 +25,7 @@ class ParticleAnalyzer:
>>> analyzer = ParticleAnalyzer(model_type="vit_h")
>>>
>>> # 2. Load image and run the analysis
>>> image = np.load(path_to_your_image)
>>> image = np.load(IMAGE_PATH)
>>> result = analyzer.analyze(image)
>>>
>>> # 3. Print summary and visualize results
Expand Down Expand Up @@ -96,6 +95,15 @@ def _download_model_if_needed(self, checkpoint_path, model_type):

def _load_model(self, checkpoint_path, model_type):
"""Loads the SAM model from a checkpoint and moves it to the device."""
try:
from segment_anything import sam_model_registry
except ImportError:
raise ImportError(
"The 'segment-anything' package is required to use this feature.\n"
"Please install it directly from the official repository:\n\n"
"pip install git+https://github.com/facebookresearch/segment-anything.git"
)

try:
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device=self.device)
Expand Down Expand Up @@ -179,6 +187,15 @@ def _preprocess_image(self, image_array, use_clahe):

def _run_sam(self, image_rgb, preset_name):
"""Initializes and runs the SAM mask generator based on a preset."""
try:
from segment_anything import SamAutomaticMaskGenerator
except ImportError:
raise ImportError(
"The 'segment-anything' package is required to use this feature.\n"
"Please install it directly from the official repository:\n\n"
"pip install git+https://github.com/facebookresearch/segment-anything.git"
)

sam_param_presets = {
"default": {},
"sensitive": {
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
'torchvision>=0.13.0',
'progressbar2>=3.38.0',
'gpytorch>=1.9.1',
'pandas>=1.1.5',
'segment-anything @ git+https://github.com/facebookresearch/segment-anything.git'
'pandas>=1.1.5'
],
classifiers=['Programming Language :: Python',
'Development Status :: 3 - Alpha',
Expand Down