|
20 | 20 |
|
21 | 21 | DEVICE = os.getenv('DEVICE', 'cuda') |
22 | 22 | SEGMENT_ANYTHING_2_REPO_PATH = os.getenv('SEGMENT_ANYTHING_2_REPO_PATH', 'segment-anything-2') |
23 | | -MODEL_CONFIG = os.getenv('MODEL_CONFIG', 'sam2_hiera_l.yaml') |
24 | | -MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2_hiera_large.pt') |
| 23 | +MODEL_CONFIG = os.getenv('MODEL_CONFIG', 'sam2.1_hiera_l.yaml') |
| 24 | +MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2.1_hiera_large.pt') |
25 | 25 | MAX_FRAMES_TO_TRACK = int(os.getenv('MAX_FRAMES_TO_TRACK', 10)) |
26 | 26 |
|
27 | 27 | if DEVICE == 'cuda': |
@@ -73,8 +73,9 @@ def split_frames(self, video_path, temp_dir, start_frame=0, end_frame=100): |
73 | 73 | # Read a frame from the video |
74 | 74 | success, frame = video.read() |
75 | 75 | if frame_count < start_frame: |
| 76 | + frame_count += 1 |
76 | 77 | continue |
77 | | - if frame_count + start_frame >= end_frame: |
| 78 | + if frame_count >= end_frame - 1: |
78 | 79 | break |
79 | 80 |
|
80 | 81 | # If frame is read correctly, success is True |
@@ -217,6 +218,10 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) - |
217 | 218 | """ Returns the predicted mask for a smart keypoint that has been placed.""" |
218 | 219 |
|
219 | 220 | from_name, to_name, value = self.get_first_tag_occurence('VideoRectangle', 'Video') |
| 221 | + |
| 222 | + if not context or not context.get('result'): |
| 223 | + # if there is no context, no interaction has happened yet |
| 224 | + return ModelResponse(predictions=[]) |
220 | 225 |
|
221 | 226 | task = tasks[0] |
222 | 227 | task_id = task['id'] |
@@ -273,22 +278,23 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) - |
273 | 278 |
|
274 | 279 | _, out_obj_ids, out_mask_logits = predictor.add_new_points( |
275 | 280 | inference_state=inference_state, |
276 | | - frame_idx=prompt['frame_idx'], |
| 281 | + frame_idx=prompt['frame_idx'] - first_frame_idx, |
277 | 282 | obj_id=obj_ids[prompt['obj_id']], |
278 | 283 | points=prompt['points'], |
279 | 284 | labels=prompt['labels'] |
280 | 285 | ) |
281 | 286 |
|
282 | 287 | sequence = [] |
283 | 288 |
|
284 | | - debug_dir = './debug-frames' |
285 | | - os.makedirs(debug_dir, exist_ok=True) |
| 289 | + #debug_dir = './debug-frames' |
| 290 | + #os.makedirs(debug_dir, exist_ok=True) |
286 | 291 |
|
287 | 292 | logger.info(f'Propagating in video from frame {last_frame_idx} to {last_frame_idx + frames_to_track}') |
| 293 | + rel_last = last_frame_idx - first_frame_idx |
288 | 294 | for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video( |
289 | 295 | inference_state=inference_state, |
290 | | - start_frame_idx=last_frame_idx, |
291 | | - max_frame_num_to_track=frames_to_track |
| 296 | + start_frame_idx=rel_last, |
| 297 | + max_frame_num_to_track=rel_last + frames_to_track |
292 | 298 | ): |
293 | 299 | real_frame_idx = out_frame_idx + first_frame_idx |
294 | 300 | for i, out_obj_id in enumerate(out_obj_ids): |
|
0 commit comments