Skip to content

Commit 1b523f1

Browse files
chyomin06 authored and fracape committed
[feat] support for sam2 image model, initial implementation
1 parent d47c579 commit 1b523f1

File tree

8 files changed

+444
-40
lines changed

8 files changed

+444
-40
lines changed

cfgs/vision_model/default.yaml

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,24 @@ sam_vit_h_4b8939:
4747
weights: "weights/segment_anything/sam_vit_h_4b8939.pth"
4848
splits: "imgenc"
4949

50+
sam2_hiera_image_model:
51+
model_path_prefix: ${..model_root_path}
52+
cfg: "models/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml"
53+
weights: "weights/sam2/sam2.1_hiera_base_plus.pt"
54+
# weights: "weights/sam2/sam2.1_hiera_large.pt"
55+
# weights: "weights/sam2/sam2.1_hiera_small.pt"
56+
# weights: "weights/sam2/sam2.1_hiera_tiny.pt"
57+
splits: "backbone"
58+
59+
sam2_hiera_video_model:
60+
model_path_prefix: ${..model_root_path}
61+
cfg: "models/sam2/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml"
62+
weights: "weights/sam2/sam2.1_hiera_base_plus.pt"
63+
# weights: "weights/sam2/sam2.1_hiera_large.pt"
64+
# weights: "weights/sam2/sam2.1_hiera_small.pt"
65+
# weights: "weights/sam2/sam2.1_hiera_tiny.pt"
66+
splits: "backbone"
67+
5068
jde_1088x608:
5169
model_path_prefix: ${..model_root_path}
5270
cfg: "models/Towards-Realtime-MOT/cfg/yolov3_1088x608.cfg"

compressai_vision/datasets/image.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,6 @@
4545
from detectron2.data.samplers import InferenceSampler
4646
from detectron2.data.transforms import AugmentationList
4747
from detectron2.utils.serialize import PicklableWrapper
48-
from jde.utils.io import read_results
4948
from PIL import Image
5049
from torch.utils.data import Dataset
5150

@@ -308,10 +307,12 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs):
308307
self.collate_fn = bypass_collator
309308

310309
_dataset = DatasetFromList(self.dataset, copy=False)
311-
mapper = SAMCustomMapper()
310+
mapper = SAMCustomMapper(
311+
augmentation_bypass=kwargs["input_augmentation_bypass"]
312+
)
312313

313314
self.mapDataset = MapDataset(_dataset, mapper)
314-
self._org_mapper_func = PicklableWrapper(SAMCustomMapper())
315+
self._org_mapper_func = PicklableWrapper(mapper)
315316

316317
metaData = MetadataCatalog.get(dataset_name)
317318
try:
@@ -551,6 +552,7 @@ def __init__(
551552
dataset_name=dataset_name,
552553
ext=ext,
553554
)
555+
from jde.utils.io import read_results
554556

555557
self.data_type = "mot"
556558
gt_frame_dict = read_results(

compressai_vision/datasets/utils.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -307,13 +307,14 @@ def __call__(self, dataset_dict):
307307

308308

309309
class SAMCustomMapper:
310-
def __init__(self, img_size=1024):
310+
def __init__(self, augmentation_bypass=False, img_size=1024):
311311
"""
312312
Args:
313313
img_size: single value - target size to SAM as input
314314
"""
315315
from segment_anything.utils.transforms import ResizeLongestSide
316316

317+
self.augmentation_bypass = augmentation_bypass
317318
self.target_size = img_size
318319
self.transform = ResizeLongestSide(img_size)
319320

@@ -335,16 +336,17 @@ def __call__(self, dataset_dict):
335336
org_img = cv2.imread(dataset_dict["file_name"]) # return img in BGR by default
336337
dataset_dict["height"], dataset_dict["width"], _ = org_img.shape
337338

338-
# h = dataset_dict["height"]
339-
# w = dataset_dict["width"]
340-
341339
# BGR --> RGB (SAM requires RGB input)
342340
org_img = org_img[..., ::-1]
343-
input_image = self.transform.apply_image(org_img)
344-
input_image = torch.tensor(input_image)
345-
input_image = input_image.permute(2, 0, 1).contiguous()[None, :, :, :]
346341

347-
dataset_dict["image"] = input_image
342+
if self.augmentation_bypass:
343+
dataset_dict["image"] = org_img.copy()
344+
else:
345+
input_image = self.transform.apply_image(org_img)
346+
input_image = torch.tensor(input_image)
347+
input_image = input_image.permute(2, 0, 1).contiguous()[None, :, :, :]
348+
349+
dataset_dict["image"] = input_image
348350

349351
return dataset_dict
350352

compressai_vision/model_wrappers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
2828
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929

30-
from . import detectron2, jde, rtmo, sam, yolox
30+
from . import detectron2, jde, rtmo, sam, sam2, yolox
3131
from .base_wrapper import BaseWrapper
3232

3333
__all__ = ["BaseWrapper"]

compressai_vision/model_wrappers/sam.py

Lines changed: 37 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,32 @@
1+
# Copyright (c) 2025, InterDigital Communications, Inc
2+
# All rights reserved.
3+
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted (subject to the limitations in the disclaimer
6+
# below) provided that the following conditions are met:
7+
8+
# * Redistributions of source code must retain the above copyright notice,
9+
# this list of conditions and the following disclaimer.
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
# * Neither the name of InterDigital Communications, Inc nor the names of its
14+
# contributors may be used to endorse or promote products derived from this
15+
# software without specific prior written permission.
16+
17+
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
18+
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
19+
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
20+
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
21+
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
22+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
25+
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
26+
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
27+
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
28+
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
130
import base64
231
import csv
332
import os
@@ -44,8 +73,10 @@ def __repr__(self):
4473

4574

4675
def mask_to_bbx(mask):
47-
mask = mask.cpu()
48-
mask = np.array(mask)
76+
if not isinstance(mask, np.ndarray):
77+
mask = mask.cpu()
78+
mask = np.array(mask)
79+
4980
mask = np.squeeze(mask)
5081
h, w = mask.shape[-2:]
5182
rows, cols = np.where(mask)
@@ -228,43 +259,31 @@ def _image_encoder_to_output(
228259
dense_prompt_embeddings=prompt_feature[1],
229260
multimask_output=False,
230261
)
231-
# print("len low_res_masks", len(low_res_masks))
262+
232263
# post process mask
233264
masks = F.interpolate(
234265
low_res_masks,
235266
(self.image_encoder.img_size, self.image_encoder.img_size),
236267
mode="bilinear",
237268
align_corners=False,
238269
)
239-
masks = masks[
240-
..., : input_img_size[0], : input_img_size[1]
241-
] # [..., : 793, : 1024]
270+
masks = masks[..., : input_img_size[0], : input_img_size[1]]
242271
masks = F.interpolate(
243272
masks,
244273
(org_img_size["height"], org_img_size["width"]),
245274
mode="bilinear",
246275
align_corners=False,
247276
)
248277

249-
# masks1 = self.postprocess_masks(
250-
# masks= low_res_masks,
251-
# input_size=input_img_size,
252-
# original_size=org_img_size,
253-
# )
254278
mask_threshold = 0.0
255279
masks = masks > mask_threshold
256-
# print("len masks", len(masks), masks[0].shape)
257-
# name = '/t/vic/hevc_simulations/rosen/build/main-20250423-sam1/masks' + str(input_img_size[0]) + '.pt'
258-
# torch.save(masks[0], name) #"/t/vic/hevc_simulations/rosen/build/main-20250423-sam1/masks.pt")
259280

260281
# post process result
261282
processed_results = []
262283
boxes = mask_to_bbx(masks[0])
263-
# print("boxes", boxes)
264284
boxes = Boxes(torch.tensor(np.array([boxes])))
265285
scores = torch.tensor([iou_pred])
266-
classes = torch.tensor(object_classes) # 48 for sandwich,
267-
# masks = torch.rand(1, 683, 1024) # Example binary mask
286+
classes = torch.tensor(object_classes)
268287

269288
from detectron2.structures import Instances
270289

@@ -273,14 +292,9 @@ def _image_encoder_to_output(
273292
instances.set("pred_boxes", boxes)
274293
instances.set("scores", scores)
275294
instances.set("pred_classes", classes)
276-
instances.set("pred_masks", masks[0]) # ✅ Now a real tensor
295+
instances.set("pred_masks", masks[0])
277296

278-
# Wrap in result
279-
# result = [f"{{'instances': {instances}}}"]
280-
# print("result", result)
281-
# print("instances", instances.get_fields().keys(), len(instances))
282297
processed_results.append({"instances": instances})
283-
# print("processed_results", len(processed_results))
284298
return processed_results
285299

286300
@torch.no_grad()

0 commit comments

Comments
 (0)