@@ -32,7 +32,8 @@ def __init__(self):
3232
3333 self .model_path = resolved_path
3434 self .device = self ._get_device ()
35- self .predictor = None
35+ self .feature_predictor = None
36+ self .inference_predictor = None
3637 self .visualizer = Sam3Visualizer ()
3738
3839 logger .info (f"SAM3 (Ultralytics) inference initialized - Model: { self .model_path } , Device: { self .device } " )
@@ -102,13 +103,27 @@ def load_model(self):
102103 task = "segment" ,
103104 mode = "predict" ,
104105 model = self .model_path ,
105- half = True if self . device == "cuda" else False , # Use FP16 for faster inference on CUDA
106+ half = False , # Use full precision as requested
106107 save = False , # Don't save prediction results to disk
107- device = self .device
108+ device = self .device ,
109+ verbose = False # Reduce noise in logs
108110 )
109111
110112 try :
111- self .predictor = SAM3SemanticPredictor (overrides = overrides )
113+ # Initialize two predictors to follow the working reference implementation exactly
114+ # One for feature extraction (image encoding), one for inference (decoding)
115+ self .feature_predictor = SAM3SemanticPredictor (overrides = overrides )
116+ self .inference_predictor = SAM3SemanticPredictor (overrides = overrides )
117+
118+ # Setup both models
119+ self .feature_predictor .setup_model ()
120+ self .inference_predictor .setup_model ()
121+
122+ # Share the underlying model to save VRAM if possible, while keeping separate predictor states
123+ if hasattr (self .feature_predictor , 'model' ) and self .feature_predictor .model is not None :
124+ self .inference_predictor .model = self .feature_predictor .model
125+ logger .info ("Shared underlying model between predictors to optimize VRAM" )
126+
112127 logger .info (f"SAM3 model loaded successfully on { self .device } (FP16: { overrides ['half' ]} )" )
113128 except Exception as e :
114129 logger .error (f"Failed to load SAM3 model: { e } " )
@@ -135,62 +150,109 @@ async def _load_image_from_upload(self, file: UploadFile) -> np.ndarray:
135150 return np .array (image_pil )
136151
137152 async def _run_inference (self , image_np , conf_threshold = None , ** kwargs ) -> tuple :
138- """Shared inference logic following test-yolo-sam3.py pattern.
153+ """Shared inference logic following the feature-based inference pattern.
139154
140155 Args:
141156 image_np: Image as numpy array
142157 conf_threshold: Confidence threshold to override model default
143158 **kwargs: Arguments to pass to predictor (text, bboxes, etc.)
144159 """
145160
146- # Update predictor confidence if specified
147- if conf_threshold is not None and hasattr (self .predictor , 'args' ):
148- self .predictor .args .conf = conf_threshold
161+ # Update predictor confidence if specified for inference predictor
162+ if conf_threshold is not None and hasattr (self .inference_predictor , 'args' ):
163+ self .inference_predictor .args .conf = conf_threshold
149164 logger .info (f"Set confidence threshold to { conf_threshold } " )
150165
151- # Set image (like predictor.set_image() in test script)
152- self .predictor .set_image (image_np )
153-
154- # Run prediction (like predictor(text=[...]) in test script)
155- results = self .predictor (** kwargs )
156-
157- # Handle results - Ultralytics returns Results object or list
158- if isinstance (results , list ):
159- result = results [0 ] if len (results ) > 0 else None
166+ # 1. Extract features using feature_predictor (like predictor.set_image())
167+ # Ultralytics models typically expect BGR format when passed numpy arrays directly
168+ # because they are built around OpenCV which uses BGR by default
169+ if image_np .ndim == 3 and image_np .shape [2 ] == 3 :
170+ logger .info ("Converting image from RGB to BGR for Ultralytics predictor" )
171+ image_input = cv2 .cvtColor (image_np , cv2 .COLOR_RGB2BGR )
160172 else :
161- result = results
173+ image_input = image_np
174+
175+ self .feature_predictor .set_image (image_input )
176+ src_shape = image_np .shape [:2 ] # Get image shape (height, width)
162177
163- if result is None :
164- logger .warning ("No results returned from predictor" )
178+ # Verify features are extracted
179+ if not hasattr (self .feature_predictor , 'features' ) or self .feature_predictor .features is None :
180+ logger .error ("Failed to extract features from image" )
181+ return None , None , [], [], []
182+
183+ # 2. Run inference using inference_predictor reusing features (like predictor2.inference_features)
184+ try :
185+ # Extract text or bboxes from kwargs
186+ text_prompts = kwargs .get ('text' , None )
187+ bboxes = kwargs .get ('bboxes' , None )
188+
189+ if text_prompts is not None :
190+ logger .info (f"Running feature-based inference with text prompts: { text_prompts } " )
191+ masks , boxes = self .inference_predictor .inference_features (
192+ self .feature_predictor .features ,
193+ src_shape = src_shape ,
194+ text = text_prompts
195+ )
196+ elif bboxes is not None :
197+ logger .info (f"Running feature-based inference with { len (bboxes )} bounding boxes" )
198+ masks , boxes = self .inference_predictor .inference_features (
199+ self .feature_predictor .features ,
200+ src_shape = src_shape ,
201+ bboxes = bboxes
202+ )
203+ else :
204+ logger .error ("No text prompts or bounding boxes provided" )
205+ return None , None , [], [], []
206+
207+ # Inference result logging
208+ box_count = 0
209+ mask_count = 0
210+
211+ if boxes is not None :
212+ box_count = len (boxes )
213+ if torch .is_tensor (boxes ):
214+ logger .info (f"Boxes tensor shape: { boxes .shape } , device: { boxes .device } " )
215+
216+ if masks is not None :
217+ mask_count = len (masks )
218+ if torch .is_tensor (masks ):
219+ logger .info (f"Masks tensor shape: { masks .shape } , device: { masks .device } " )
220+
221+ logger .info (f"Feature-based inference successful: { mask_count } masks, { box_count } boxes" )
222+
223+ except Exception as e :
224+ logger .error (f"Feature-based inference failed: { e } " , exc_info = True )
165225 return None , None , [], [], []
166226
167227 boxes_list = []
168228 scores_list = []
169229 masks_polygon = []
170230 masks_tensor = None
171231
172- # Extract masks
173- if hasattr ( result , ' masks' ) and result . masks is not None :
174- masks_tensor = result . masks . data # Get mask tensors [N, H, W]
232+ # Process masks
233+ if masks is not None and len ( masks ) > 0 :
234+ masks_tensor = masks # Already tensor format from inference_features
175235 # Convert masks to polygon format
176236 masks_polygon = masks_to_polygon_data (masks_tensor )
177237 logger .info (f"Extracted { len (masks_polygon )} mask(s)" )
178238 else :
179239 logger .warning ("No masks found in results" )
180240
181- # Extract boxes and scores
182- if hasattr (result , 'boxes' ) and result .boxes is not None :
183- boxes_list = result .boxes .xyxy .cpu ().tolist () # [x1, y1, x2, y2] format
184- if hasattr (result .boxes , 'conf' ) and result .boxes .conf is not None :
185- scores_list = result .boxes .conf .cpu ().tolist ()
186- else :
187- # Default confidence if not available
188- scores_list = [1.0 ] * len (boxes_list )
189- logger .info (f"Extracted { len (boxes_list )} box(es) with confidences: { scores_list } " )
241+ # Process boxes - inference_features returns boxes as tensor
242+ if boxes is not None and len (boxes ) > 0 :
243+ boxes_list = boxes .cpu ().tolist () if torch .is_tensor (boxes ) else boxes .tolist ()
244+ # inference_features doesn't return confidence scores directly
245+ # Use default confidence based on threshold
246+ scores_list = [conf_threshold if conf_threshold else settings .SAM3_DEFAULT_THRESHOLD ] * len (boxes_list )
247+ logger .info (f"Extracted { len (boxes_list )} box(es)" )
248+
249+ # Log confidence for each detection
250+ for idx , (box , score ) in enumerate (zip (boxes_list , scores_list )):
251+ logger .info (f" Detection #{ idx + 1 } : bbox={ box } , confidence={ score :.4f} " )
190252 else :
191253 logger .warning ("No boxes found in results" )
192254
193- return result , masks_tensor , boxes_list , scores_list , masks_polygon
255+ return None , masks_tensor , boxes_list , scores_list , masks_polygon
194256
195257
196258 async def inference_text (
0 commit comments