@@ -119,6 +119,7 @@ def INPUT_TYPES(cls):
119119 "mask_blur" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 64 , "step" : 1 }),
120120 "mask_offset" : ("INT" , {"default" : 0 , "min" : - 64 , "max" : 64 , "step" : 1 }),
121121 "invert_output" : ("BOOLEAN" , {"default" : False }),
122+ "unload_model" : ("BOOLEAN" , {"default" : False }),
122123 "background" : (["Alpha" , "Color" ], {"default" : "Alpha" }),
123124 "background_color" : ("COLORCODE" , {"default" : "#222222" }),
124125 },
@@ -190,14 +191,18 @@ def _run_single(self, processor, img_tensor, prompt, confidence, mask_blur, mask
190191 mask_rgb = mask_tensor .reshape ((- 1 , 1 , mask_image .height , mask_image .width )).movedim (1 , - 1 ).expand (- 1 , - 1 , - 1 , 3 )
191192 return result_image , mask_tensor , mask_rgb
192193
193- def segment (self , image , prompt , sam3_model , device , confidence_threshold = 0.5 , mask_blur = 0 , mask_offset = 0 , invert_output = False , background = "Alpha" , background_color = "#222222" ):
194+ def segment (self , image , prompt , sam3_model , device , confidence_threshold = 0.5 , mask_blur = 0 , mask_offset = 0 , invert_output = False , unload_model = False , background = "Alpha" , background_color = "#222222" ):
195+
194196 if image .ndim == 3 :
195197 image = image .unsqueeze (0 )
198+
196199 processor , torch_device = self ._load_processor (sam3_model , device )
197200 autocast_device = comfy .model_management .get_autocast_device (torch_device )
198201 autocast_enabled = torch_device .type == "cuda" and not comfy .model_management .is_device_mps (torch_device )
199202 ctx = torch .autocast (autocast_device , dtype = torch .bfloat16 ) if autocast_enabled else nullcontext ()
203+
200204 result_images , result_masks , result_mask_images = [], [], []
205+
201206 with ctx :
202207 for tensor_img in image :
203208 img_pil , mask_tensor , mask_rgb = self ._run_single (
@@ -214,6 +219,15 @@ def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, m
214219 result_images .append (pil2tensor (img_pil ))
215220 result_masks .append (mask_tensor )
216221 result_mask_images .append (mask_rgb )
222+
223+ if unload_model :
224+ device_str = "cuda" if torch_device .type == "cuda" else "cpu"
225+ cache_key = (sam3_model , device_str )
226+ if cache_key in self .processor_cache :
227+ del self .processor_cache [cache_key ]
228+ if torch_device .type == "cuda" :
229+ torch .cuda .empty_cache ()
230+
217231 return torch .cat (result_images , dim = 0 ), torch .cat (result_masks , dim = 0 ), torch .cat (result_mask_images , dim = 0 )
218232
219233
@@ -223,4 +237,4 @@ def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, m
223237
224238NODE_DISPLAY_NAME_MAPPINGS = {
225239 "SAM3Segment" : "SAM3 Segmentation (RMBG)" ,
226- }
240+ }
0 commit comments