Skip to content

Commit 0236d5b

Browse files
author
Alessandro Marzo
committed
fix: make SAM 2 work in case data value is a list of images
1 parent e672ccf commit 0236d5b

File tree

1 file changed

+16
-4
lines changed
  • label_studio_ml/examples/segment_anything_2_image

1 file changed

+16
-4
lines changed

label_studio_ml/examples/segment_anything_2_image/model.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class NewModel(LabelStudioMLBase):
4343
"""Custom ML Backend model
4444
"""
4545

46-
def get_results(self, masks, probs, width, height, from_name, to_name, label):
46+
def get_results(self, masks, probs, width, height, from_name, to_name, label, item_index):
4747
results = []
4848
total_prob = 0
4949
for mask, prob in zip(masks, probs):
@@ -53,7 +53,7 @@ def get_results(self, masks, probs, width, height, from_name, to_name, label):
5353
mask = mask * 255
5454
rle = brush.mask2rle(mask)
5555
total_prob += prob
56-
results.append({
56+
annotation_result = {
5757
'id': label_id,
5858
'from_name': from_name,
5959
'to_name': to_name,
@@ -68,7 +68,13 @@ def get_results(self, masks, probs, width, height, from_name, to_name, label):
6868
'score': prob,
6969
'type': 'brushlabels',
7070
'readonly': False
71-
})
71+
}
72+
73+
74+
if item_index is not None:
75+
annotation_result['item_index'] = item_index
76+
77+
results.append(annotation_result)
7278

7379
return [{
7480
'result': results,
@@ -139,6 +145,11 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
139145
print(f'Point coords are {point_coords}, point labels are {point_labels}, input box is {input_box}')
140146

141147
img_url = tasks[0]['data'][value]
148+
if isinstance(img_url, list):
149+
item_index = context['result'][0]['item_index']
150+
img_url = img_url[item_index]
151+
else:
152+
item_index = None
142153
predictor_results = self._sam_predict(
143154
img_url=img_url,
144155
point_coords=point_coords or None,
@@ -154,6 +165,7 @@ def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -
154165
height=image_height,
155166
from_name=from_name,
156167
to_name=to_name,
157-
label=selected_label)
168+
label=selected_label,
169+
item_index=item_index)
158170

159171
return ModelResponse(predictions=predictions)

0 commit comments

Comments
 (0)