Skip to content

Commit f9cf1e9

Browse files
authored
Update model.py
- Fix #764 - Fix bug: script would not work properly for start_frame > 0, relative indexes were not being calculated.
1 parent 84d698a commit f9cf1e9

File tree

1 file changed

+14
-8
lines changed
  • label_studio_ml/examples/segment_anything_2_video

1 file changed

+14
-8
lines changed

label_studio_ml/examples/segment_anything_2_video/model.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
DEVICE = os.getenv('DEVICE', 'cuda')
2222
SEGMENT_ANYTHING_2_REPO_PATH = os.getenv('SEGMENT_ANYTHING_2_REPO_PATH', 'segment-anything-2')
23-
MODEL_CONFIG = os.getenv('MODEL_CONFIG', 'sam2_hiera_l.yaml')
24-
MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2_hiera_large.pt')
23+
MODEL_CONFIG = os.getenv('MODEL_CONFIG', 'sam2.1_hiera_l.yaml')
24+
MODEL_CHECKPOINT = os.getenv('MODEL_CHECKPOINT', 'sam2.1_hiera_large.pt')
2525
MAX_FRAMES_TO_TRACK = int(os.getenv('MAX_FRAMES_TO_TRACK', 10))
2626

2727
if DEVICE == 'cuda':
@@ -73,8 +73,9 @@ def split_frames(self, video_path, temp_dir, start_frame=0, end_frame=100):
7373
# Read a frame from the video
7474
success, frame = video.read()
7575
if frame_count < start_frame:
76+
frame_count += 1
7677
continue
77-
if frame_count + start_frame >= end_frame:
78+
if frame_count >= end_frame - 1:
7879
break
7980

8081
# If frame is read correctly, success is True
@@ -217,6 +218,10 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
217218
""" Returns the predicted mask for a smart keypoint that has been placed."""
218219

219220
from_name, to_name, value = self.get_first_tag_occurence('VideoRectangle', 'Video')
221+
222+
if not context or not context.get('result'):
223+
# if there is no context, no interaction has happened yet
224+
return ModelResponse(predictions=[])
220225

221226
task = tasks[0]
222227
task_id = task['id']
@@ -273,22 +278,23 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
273278

274279
_, out_obj_ids, out_mask_logits = predictor.add_new_points(
275280
inference_state=inference_state,
276-
frame_idx=prompt['frame_idx'],
281+
frame_idx=prompt['frame_idx'] - first_frame_idx,
277282
obj_id=obj_ids[prompt['obj_id']],
278283
points=prompt['points'],
279284
labels=prompt['labels']
280285
)
281286

282287
sequence = []
283288

284-
debug_dir = './debug-frames'
285-
os.makedirs(debug_dir, exist_ok=True)
289+
#debug_dir = './debug-frames'
290+
#os.makedirs(debug_dir, exist_ok=True)
286291

287292
logger.info(f'Propagating in video from frame {last_frame_idx} to {last_frame_idx + frames_to_track}')
293+
rel_last = last_frame_idx - first_frame_idx
288294
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
289295
inference_state=inference_state,
290-
start_frame_idx=last_frame_idx,
291-
max_frame_num_to_track=frames_to_track
296+
start_frame_idx=rel_last,
297+
max_frame_num_to_track=rel_last + frames_to_track
292298
):
293299
real_frame_idx = out_frame_idx + first_frame_idx
294300
for i, out_obj_id in enumerate(out_obj_ids):

0 commit comments

Comments
 (0)