|
| 1 | +import os |
| 2 | +import sys |
| 3 | +from contextlib import nullcontext |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import torch |
| 7 | +from PIL import Image, ImageFilter |
| 8 | +from torch.hub import download_url_to_file |
| 9 | + |
| 10 | +import folder_paths |
| 11 | +import comfy.model_management |
| 12 | + |
| 13 | +from AILab_ImageMaskTools import pil2tensor, tensor2pil |
| 14 | + |
# Directory that contains this module.
CURRENT_DIR = os.path.dirname(__file__)

# Local "sam3" directory next to this module; it is pushed to the front of
# sys.path (only once) before the sam3 imports that follow.
SAM3_LOCAL_DIR = os.path.join(CURRENT_DIR, "sam3")
if SAM3_LOCAL_DIR not in sys.path:
    sys.path.insert(0, SAM3_LOCAL_DIR)

# Location of the BPE vocabulary file handed to the model builder.
# Importing this module fails immediately when the file is not present.
SAM3_BPE_PATH = os.path.join(SAM3_LOCAL_DIR, "assets", "bpe_simple_vocab_16e6.txt.gz")
if not os.path.isfile(SAM3_BPE_PATH):
    raise RuntimeError("SAM3 assets missing; ensure sam3/assets/bpe_simple_vocab_16e6.txt.gz exists.")
| 23 | + |
| 24 | +from sam3.model_builder import build_sam3_image_model # noqa: E402 |
| 25 | +from sam3.model.sam3_image_processor import Sam3Processor # noqa: E402 |
| 26 | + |
# Download URL and on-disk file name of the default SAM3 PyTorch checkpoint.
_DEFAULT_PT_ENTRY = dict(
    model_url="https://huggingface.co/1038lab/sam3/resolve/main/sam3.pt",
    filename="sam3.pt",
)

# Registry of selectable models keyed by model name. The value is an
# independent copy, so editing it does not alter _DEFAULT_PT_ENTRY.
SAM3_MODELS = {"sam3": dict(_DEFAULT_PT_ENTRY)}
| 35 | + |
| 36 | + |
def get_sam3_pt_models():
    """Return a one-entry dict ``{"sam3": info}`` describing the PT checkpoint.

    Resolution order:
      1. the ``"sam3"`` entry of ``SAM3_MODELS`` when its filename ends in ``.pt``;
      2. otherwise the first registry entry (in iteration order) that either
         already has a ``.pt`` filename, or whose key contains ``"sam3"`` -- in
         the latter case a copy is returned with URL and filename replaced by
         the defaults;
      3. otherwise a copy of the default entry.

    Entries returned by steps 1 and the ``.pt`` branch of step 2 are the
    registry's own dict objects, not copies.
    """

    def _has_pt_filename(info):
        return info.get("filename", "").endswith(".pt")

    primary = SAM3_MODELS.get("sam3")
    if primary and _has_pt_filename(primary):
        return {"sam3": primary}

    # Fallback: upgrade any legacy entry to PT naming.
    for name, info in SAM3_MODELS.items():
        if _has_pt_filename(info):
            return {"sam3": info}
        if "sam3" in name and info:
            upgraded = {
                **info,
                "model_url": _DEFAULT_PT_ENTRY["model_url"],
                "filename": _DEFAULT_PT_ENTRY["filename"],
            }
            return {"sam3": upgraded}

    return {"sam3": dict(_DEFAULT_PT_ENTRY)}
| 52 | + |
| 53 | + |
def process_mask(mask_image, invert_output=False, mask_blur=0, mask_offset=0):
    """Post-process a mask image: optional invert, blur, then grow/shrink.

    Args:
        mask_image: PIL image holding the mask. The inversion computes
            ``255 - pixels``, so an 8-bit single-channel image is assumed --
            TODO confirm with callers.
        invert_output: When true, invert the mask before any other step.
        mask_blur: Gaussian blur radius; applied only when greater than 0.
        mask_offset: Number of pixels to grow (positive) or shrink (negative)
            the mask; 0 leaves it unchanged.

    Returns:
        The processed image, or the input object itself when every option is
        at its neutral value.
    """
    if invert_output:
        mask_image = Image.fromarray(255 - np.array(mask_image))

    if mask_blur > 0:
        mask_image = mask_image.filter(ImageFilter.GaussianBlur(radius=mask_blur))

    if mask_offset != 0:
        rank_filter = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
        # A single pass with a (2n + 1)-wide window already moves the mask
        # edge by n pixels. Repeating that same window n times (as was done
        # before) shifted the edge by roughly n * n pixels and was very slow
        # for large offsets.
        window = abs(mask_offset) * 2 + 1
        mask_image = mask_image.filter(rank_filter(window))

    return mask_image
| 66 | + |
| 67 | + |
def apply_background_color(image, mask_image, background="Alpha", background_color="#222222"):
    """Use ``mask_image`` as alpha for ``image`` and optionally flatten it.

    Args:
        image: Source PIL image; it is copied, never modified in place.
        mask_image: PIL image converted to mode ``"L"`` and used as the alpha
            channel. It must have the same size as ``image`` for ``putalpha``
            to accept it.
        background: ``"Color"`` composites the cut-out over a solid colour
            and returns an RGB image; any other value returns the RGBA
            cut-out.
        background_color: Hex colour used when ``background == "Color"``.
            Accepts ``RRGGBB`` (extra trailing digits are ignored) or the
            ``RGB`` shorthand, with or without a leading ``#``.

    Returns:
        An RGB image when ``background == "Color"``, otherwise an RGBA image.

    Raises:
        ValueError: If ``background_color`` is not a usable hex colour.
    """
    rgba_image = image.copy().convert("RGBA")
    rgba_image.putalpha(mask_image.convert("L"))
    if background != "Color":
        return rgba_image

    error_message = (
        f"Invalid background_color {background_color!r}; expected '#RRGGBB' or '#RGB'."
    )
    hex_color = background_color.strip().lstrip("#")
    if len(hex_color) in (3, 4):
        # Expand shorthand such as "#abc" to "aabbcc".
        hex_color = "".join(digit * 2 for digit in hex_color)
    if len(hex_color) < 6:
        raise ValueError(error_message)
    try:
        r, g, b = (int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
    except ValueError as err:
        raise ValueError(error_message) from err

    bg_image = Image.new("RGBA", image.size, (r, g, b, 255))
    return Image.alpha_composite(bg_image, rgba_image).convert("RGB")
| 78 | + |
| 79 | + |
def get_or_download_model_file(filename, url):
    """Return a local path for ``filename``, downloading from ``url`` if needed.

    The file is first looked up through ``folder_paths.get_full_path`` under
    the ``"sam3"`` folder name (when that function exists). If that does not
    yield an existing file, the target becomes ``<models_dir>/sam3/<filename>``
    where ``models_dir`` comes from ``folder_paths`` or, failing that, a
    ``models`` directory next to this module. The directory is created when
    missing and the file is downloaded only if nothing exists at the target.
    """
    if hasattr(folder_paths, "get_full_path"):
        registered = folder_paths.get_full_path("sam3", filename)
        if registered and os.path.isfile(registered):
            return registered

    fallback_root = os.path.join(CURRENT_DIR, "models")
    target_dir = os.path.join(getattr(folder_paths, "models_dir", fallback_root), "sam3")
    os.makedirs(target_dir, exist_ok=True)

    target = os.path.join(target_dir, filename)
    if os.path.exists(target):
        return target

    print(f"Downloading {filename} from {url} ...")
    download_url_to_file(url, target)
    return target
| 94 | + |
| 95 | + |
def _resolve_device(user_choice):
    """Translate the node's device option into a ``torch.device``.

    ``"CPU"`` always yields the CPU device. ``"GPU"`` yields ``cuda`` but only
    when the device reported by ``comfy.model_management`` is a CUDA device;
    otherwise ``RuntimeError("GPU unavailable")`` is raised. Any other value
    returns the device reported by ``comfy.model_management`` unchanged.
    """
    detected = comfy.model_management.get_torch_device()
    if user_choice == "CPU":
        return torch.device("cpu")
    if user_choice != "GPU":
        return detected
    if detected.type == "cuda":
        return torch.device("cuda")
    raise RuntimeError("GPU unavailable")
| 105 | + |
| 106 | + |
class SAM3Segment:
    """Node that segments every image of a batch from a text prompt with SAM3.

    For each image the detected masks are merged into one mask, post-processed
    (invert / blur / offset) and applied to the image either as an alpha
    channel or composited over a solid background colour.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input schema (required and optional inputs)."""
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "Describe the concept"}),
                "sam3_model": (list(SAM3_MODELS.keys()), {"default": "sam3"}),
                "device": (["Auto", "CPU", "GPU"], {"default": "Auto"}),
                "confidence_threshold": ("FLOAT", {"default": 0.5, "min": 0.05, "max": 0.95, "step": 0.01}),
            },
            "optional": {
                "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
                "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1}),
                "invert_output": ("BOOLEAN", {"default": False}),
                "background": (["Alpha", "Color"], {"default": "Alpha"}),
                "background_color": ("COLORCODE", {"default": "#222222"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
    RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
    FUNCTION = "segment"
    CATEGORY = "🧪AILab/🧽RMBG"

    def __init__(self):
        # Loaded processors keyed by (model name, "cuda" | "cpu") so a model
        # is built at most once per device for this node instance.
        self.processor_cache = {}

    def _load_processor(self, model_choice, device_choice):
        """Return ``(processor, torch_device)`` for the requested model/device.

        The processor is built on first use and cached. Only ``"cuda"`` or
        ``"cpu"`` is handed to the model builder: any resolved device whose
        type is not ``cuda`` is treated as ``"cpu"``. The returned
        ``torch_device`` is the resolved device itself.

        Raises:
            KeyError: If ``model_choice`` is not a key of ``SAM3_MODELS``.
            RuntimeError: Propagated from ``_resolve_device``.
        """
        torch_device = _resolve_device(device_choice)
        device_str = "cuda" if torch_device.type == "cuda" else "cpu"
        cache_key = (model_choice, device_str)
        if cache_key not in self.processor_cache:
            model_info = SAM3_MODELS[model_choice]
            ckpt_path = get_or_download_model_file(model_info["filename"], model_info["model_url"])
            model = build_sam3_image_model(
                bpe_path=SAM3_BPE_PATH,
                device=device_str,
                eval_mode=True,
                checkpoint_path=ckpt_path,
                load_from_HF=False,
                enable_segmentation=True,
                enable_inst_interactivity=False,
            )
            self.processor_cache[cache_key] = Sam3Processor(model, device=device_str)
        return self.processor_cache[cache_key], torch_device

    def _build_outputs(self, img_pil, mask_image, invert, mask_blur, mask_offset, background, background_color):
        """Post-process ``mask_image`` and derive the three per-image outputs.

        Returns:
            ``(result_image, mask_tensor, mask_rgb)`` where ``result_image``
            is a PIL image (RGBA when ``background == "Alpha"``, RGB
            otherwise), ``mask_tensor`` is a float32 tensor of shape
            ``(1, H, W)`` scaled to ``[0, 1]`` and ``mask_rgb`` is that mask
            expanded to shape ``(1, H, W, 3)``.
        """
        mask_image = process_mask(mask_image, invert, mask_blur, mask_offset)
        result_image = apply_background_color(img_pil, mask_image, background, background_color)
        result_image = result_image.convert("RGBA" if background == "Alpha" else "RGB")
        mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
        mask_rgb = (
            mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width))
            .movedim(1, -1)
            .expand(-1, -1, -1, 3)
        )
        return result_image, mask_tensor, mask_rgb

    def _empty_result(self, img_pil, background, background_color, invert=False):
        """Build the outputs for an image in which nothing was detected.

        The starting mask is all zeros. With ``invert`` set it becomes all
        ones, so an inverted request keeps the whole image instead of
        returning an empty mask. Blur and offset are skipped because they
        cannot change a uniform mask.
        """
        blank_mask = Image.new("L", img_pil.size, 0)
        return self._build_outputs(img_pil, blank_mask, invert, 0, 0, background, background_color)

    def _run_single(self, processor, img_tensor, prompt, confidence, mask_blur, mask_offset, invert, background, background_color):
        """Segment one image tensor and return ``(result_image, mask, mask_rgb)``.

        A blank prompt falls back to the text ``"object"``. All masks the
        processor reports are merged with an element-wise maximum into one
        mask before post-processing.
        """
        img_pil = tensor2pil(img_tensor)
        text = prompt.strip() or "object"
        state = processor.set_image(img_pil)
        processor.reset_all_prompts(state)
        processor.set_confidence_threshold(confidence, state)
        state = processor.set_text_prompt(text, state)
        masks = state.get("masks")
        if masks is None or masks.numel() == 0:
            # Honour invert_output even when no mask was produced.
            return self._empty_result(img_pil, background, background_color, invert)
        masks = masks.float().to("cpu")
        if masks.ndim == 4:
            # Drop the singleton second dimension so masks are stacked on dim 0.
            masks = masks.squeeze(1)
        combined = masks.amax(dim=0)
        mask_np = (combined.clamp(0, 1).numpy() * 255).astype(np.uint8)
        mask_image = Image.fromarray(mask_np, mode="L")
        return self._build_outputs(img_pil, mask_image, invert, mask_blur, mask_offset, background, background_color)

    def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, mask_blur=0, mask_offset=0, invert_output=False, background="Alpha", background_color="#222222"):
        """Run SAM3 segmentation over every image in ``image``.

        Args:
            image: Image tensor; a 3-D tensor gets a leading batch dimension.
                Presumably laid out as (batch, height, width, channels) --
                TODO confirm against ``tensor2pil``.
            prompt: Text describing what to segment.
            sam3_model: Key into ``SAM3_MODELS``.
            device: ``"Auto"``, ``"CPU"`` or ``"GPU"``.
            confidence_threshold: Threshold forwarded to the processor.
            mask_blur, mask_offset, invert_output: Mask post-processing
                options, see ``process_mask``.
            background, background_color: Output compositing options, see
                ``apply_background_color``.

        Returns:
            Tuple ``(images, masks, mask_images)``; each is the per-image
            results concatenated along dimension 0.
        """
        if image.ndim == 3:
            image = image.unsqueeze(0)
        processor, torch_device = self._load_processor(sam3_model, device)
        autocast_device = comfy.model_management.get_autocast_device(torch_device)
        autocast_enabled = torch_device.type == "cuda" and not comfy.model_management.is_device_mps(torch_device)
        # bfloat16 autocast is used on CUDA only; other devices run without it.
        ctx = torch.autocast(autocast_device, dtype=torch.bfloat16) if autocast_enabled else nullcontext()
        result_images, result_masks, result_mask_images = [], [], []
        with ctx:
            for tensor_img in image:
                img_pil, mask_tensor, mask_rgb = self._run_single(
                    processor,
                    tensor_img,
                    prompt,
                    confidence_threshold,
                    mask_blur,
                    mask_offset,
                    invert_output,
                    background,
                    background_color,
                )
                result_images.append(pil2tensor(img_pil))
                result_masks.append(mask_tensor)
                result_mask_images.append(mask_rgb)
        return torch.cat(result_images, dim=0), torch.cat(result_masks, dim=0), torch.cat(result_mask_images, dim=0)
| 218 | + |
| 219 | + |
# Node identifier -> implementing class (presumably read by the ComfyUI
# custom-node loader; verify against the host application).
NODE_CLASS_MAPPINGS = {"SAM3Segment": SAM3Segment}

# Node identifier -> human-readable title.
NODE_DISPLAY_NAME_MAPPINGS = {"SAM3Segment": "SAM3 Segmentation (RMBG)"}
0 commit comments