
Commit b8b3b97

[fix] initial support - SAMv1
1 parent 8fd0512 commit b8b3b97

File tree: 10 files changed (+132 / -131 lines)

cfgs/pipeline/remote_inference.yaml

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,8 @@ codec:
   skip_n_frames: 0 # This is encoder only option
   n_frames_to_be_encoded: -1 #(-1 = encode all input), This is encoder only option
   measure_complexity: "${codec.mac_computation}"
+  vcm_mode: "${codec.vcm_mode}"
+  output10b: "${codec.output10b}"
 nn_task:
   dump_results: False
   output_results_dir: "${codec.output_dir}/output_results"
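Note: the two new keys follow the same interpolation pattern as measure_complexity above. A minimal standalone sketch of how such "${codec.*}" references resolve, assuming OmegaConf (the interpolation syntax used throughout these configs); the key layout below is illustrative only, not the project's actual config tree:

    from omegaconf import OmegaConf

    # Illustrative config tree: a "codec" section holding the source values and a
    # consumer section that points back at them via interpolation.
    cfg = OmegaConf.create(
        {
            "codec": {"vcm_mode": True, "output10b": False},
            "remote_inference": {
                "vcm_mode": "${codec.vcm_mode}",
                "output10b": "${codec.output10b}",
            },
        }
    )

    print(cfg.remote_inference.vcm_mode)   # True  (resolved from codec.vcm_mode)
    print(cfg.remote_inference.output10b)  # False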

cfgs/vision_model/default.yaml

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,8 @@ mask_rcnn_R_50_FPN_3x:
 
 sam_vit_h_4b8939:
   model_path_prefix: ${..model_root_path}
-  weights: "weights/sam/sam_vit_h_4b8939.pth"
+  cfg: "Built-in configurations"
+  weights: "weights/segment_anything/sam_vit_h_4b8939.pth"
   splits: "imgenc"
 
 mask_rcnn_X_101_32x8d_FPN_3x:

compressai_vision/codecs/base.py

Lines changed: 1 addition & 0 deletions
@@ -143,6 +143,7 @@ def decode(
         org_img_size: Dict = None,
         remote_inference=False,
         vcm_mode=False,
+        output10b=False,
     ):
         del org_img_size
         del file_prefix  # used in other codecs that write log files

compressai_vision/codecs/ffmpeg.py

Lines changed: 1 addition & 0 deletions
@@ -287,6 +287,7 @@ def decode(
         org_img_size: Dict = None,
         remote_inference=False,
         vcm_mode=False,
+        output10b=False,
     ) -> bool:
         """
         Decodes a bitstream into video frames and extract features from the decoded frames.

compressai_vision/datasets/image.py

Lines changed: 2 additions & 10 deletions
@@ -293,18 +293,10 @@ def __init__(self, root, dataset_name, imgs_folder, **kwargs):
         self.collate_fn = bypass_collator
 
         _dataset = DatasetFromList(self.dataset, copy=False)
-
-        # if kwargs["linear_mapper"] is True:
-        #     mapper = LinearMapper()
-        # else:
-        #     assert (
-        #         kwargs["cfg"] is not None
-        #     ), "A proper mapper information via cfg must be provided"
-        #     mapper = DatasetMapper(kwargs["cfg"], False)
-        mapper = SAMCustomMapper(kwargs["patch_size"])
+        mapper = SAMCustomMapper()
 
         self.mapDataset = MapDataset(_dataset, mapper)
-        self._org_mapper_func = PicklableWrapper(SAMCustomMapper(kwargs["patch_size"]))
+        self._org_mapper_func = PicklableWrapper(SAMCustomMapper())
 
         metaData = MetadataCatalog.get(dataset_name)
         try:
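For context, a minimal sketch of how a mapper of this kind plugs into detectron2's dataset utilities. DatasetFromList, MapDataset (and PicklableWrapper used above) are detectron2 classes; the toy mapper and the record below are stand-ins for illustration, not project code:

    from detectron2.data.common import DatasetFromList, MapDataset

    def toy_mapper(dataset_dict):
        # Stand-in for SAMCustomMapper.__call__: the real mapper loads the image,
        # resizes its longest side to 1024 and stores a (1, 3, h, w) tensor
        # under dataset_dict["image"].
        out = dict(dataset_dict)
        out["image"] = None
        return out

    records = [{"file_name": "images/0001.jpg", "height": 480, "width": 640}]
    dataset = MapDataset(DatasetFromList(records, copy=False), toy_mapper)
    print(dataset[0]["height"], dataset[0]["width"])  # 480 640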

compressai_vision/datasets/utils.py

Lines changed: 13 additions & 17 deletions
@@ -37,6 +37,7 @@
 
 from jde.utils.datasets import letterbox
 from mmpose.structures.bbox import get_warp_matrix
+from segment_anything.utils.transforms import ResizeLongestSide
 from torch.nn import functional as F
 from torchvision import transforms
 
@@ -307,15 +308,13 @@ def __call__(self, dataset_dict):
 
 
 class SAMCustomMapper:
-    def __init__(self, img_size=[1024, 1024]):
+    def __init__(self, img_size=1024):
         """
         Args:
-            img_size: expected input size (Height, Width)
+            img_size: single value - target size to SAM as input
         """
-        self.height = 1024
-        self.width = 1024
-        self.pixel_mean = [123.675, 116.28, 103.53]
-        self.pixel_std = [58.395, 57.12, 57.375]
+        self.target_size = img_size
+        self.transform = ResizeLongestSide(img_size)
 
     def __call__(self, dataset_dict):
         """
@@ -337,17 +336,14 @@ def __call__(self, dataset_dict):
 
         h = dataset_dict["height"]
         w = dataset_dict["width"]
-        org_img = (org_img - self.pixel_mean) / self.pixel_std
-
-        padh = self.height - h  # self.image_encoder.img_size - h
-        padw = self.width - w
-        image = torch.tensor(org_img)
-        image = image.unsqueeze(-1)
-        image = image.permute(3, 2, 0, 1)
-        image = F.pad(image, (0, padw, 0, padh))
-        image = image.to(torch.float32)
-        # to tensor
-        dataset_dict["image"] = image
+
+        # BGR --> RGB (SAM requires RGB input)
+        org_img = org_img[..., ::-1]
+        input_image = self.transform.apply_image(org_img)
+        input_image = torch.tensor(input_image)
+        input_image = input_image.permute(2, 0, 1).contiguous()[None, :, :, :]
+
+        dataset_dict["image"] = input_image
 
         return dataset_dict
 
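A quick shape check of the new preprocessing path in SAMCustomMapper (assumes the segment_anything package is installed; the input is a dummy BGR array and 480x640 is an arbitrary example size):

    import numpy as np
    import torch
    from segment_anything.utils.transforms import ResizeLongestSide

    org_img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # H x W x C, BGR
    org_img = org_img[..., ::-1]                       # BGR -> RGB, as in the mapper
    transform = ResizeLongestSide(1024)
    resized = transform.apply_image(org_img)           # longest side scaled to 1024
    image = torch.tensor(resized).permute(2, 0, 1).contiguous()[None, :, :, :]
    print(resized.shape, image.shape)                  # (768, 1024, 3) torch.Size([1, 3, 768, 1024])

Normalization and padding to the 1024x1024 square, previously done by hand in the mapper, is now left to the model side (the wrapper calls self.model.preprocess on this tensor).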

compressai_vision/model_wrappers/sam.py

Lines changed: 72 additions & 86 deletions
@@ -13,8 +13,8 @@
 
 from detectron2.structures import ImageList, Instances
 from segment_anything import (  # , Instances
-    SamAutomaticMaskGenerator,
-    SamPredictor,
+    # SamAutomaticMaskGenerator,
+    # SamPredictor,
     sam_model_registry,
 )
 from torch.nn import functional as F
@@ -50,6 +50,7 @@ def __repr__(self):
 
 
 def mask_to_bbx(mask):
+    mask = mask.cpu()
     mask = np.array(mask)
     mask = np.squeeze(mask)
     h, w = mask.shape[-2:]
@@ -81,17 +82,26 @@ class SAM(BaseWrapper):
     def __init__(self, device: str, **kwargs):
         super().__init__(device)
 
+        _path_prefix = (
+            f"{root_path}"
+            if kwargs["model_path_prefix"] == "default"
+            else kwargs["model_path_prefix"]
+        )
+        self.model_info = {
+            "cfg": f"{_path_prefix}/{kwargs['cfg']}",
+            "weights": f"{_path_prefix}/{kwargs['weights']}",
+        }
+
         self.model = (
-            sam_model_registry["vit_h"](checkpoint=kwargs["weights"]).to(device).eval()
+            sam_model_registry["vit_h"](checkpoint=self.model_info["weights"])
+            .to(device)
+            .eval()
         )
-        self.model.load_state_dict(torch.load(kwargs["weights"]))
 
-        self.backbone = self.model.image_encoder
+        self.image_encoder = self.model.image_encoder
         self.prompt_encoder = self.model.prompt_encoder
         self.head = self.model.mask_decoder
 
-        # SamPredictor(self.model)
-        # print(SamPredictor)
         self.supported_split_points = Split_Points
 
         assert "splits" in kwargs, "Split layer ids must be provided"
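The new model_info block simply prepends a root prefix to the paths declared in cfgs/vision_model/default.yaml. A standalone sketch of that resolution; root_path is referenced in the diff but not defined in it, so it is a placeholder here, and the kwargs mirror the sam_vit_h_4b8939 entry above:

    root_path = "/path/to/compressai_vision"   # placeholder, not the real constant
    kwargs = {
        "model_path_prefix": "default",
        "cfg": "Built-in configurations",
        "weights": "weights/segment_anything/sam_vit_h_4b8939.pth",
    }

    _path_prefix = (
        root_path if kwargs["model_path_prefix"] == "default" else kwargs["model_path_prefix"]
    )
    model_info = {
        "cfg": f"{_path_prefix}/{kwargs['cfg']}",
        "weights": f"{_path_prefix}/{kwargs['weights']}",
    }
    print(model_info["weights"])
    # /path/to/compressai_vision/weights/segment_anything/sam_vit_h_4b8939.pth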
@@ -106,18 +116,31 @@ def __init__(self, device: str, **kwargs):
             zip(self.split_layer_list, [None] * len(self.split_layer_list))
         )
 
-        self.annotation_file = "/o/projects/proj-river/ctc_sequences/vcm_testdata/samtest/annotations/mpeg-oiv6-segmentation-coco_fortest.json"
-
     @property
     def SPLIT_IMGENC(self):
         return str(self.supported_split_points.ImageEncoder)
 
-    def input_to_features(self, x, device: str) -> Dict:
+    @staticmethod
+    def prompt_inputs(file_name):
+        # [TODO] should be improved...
+        prompt_link = file_name.replace("/images/", "/prompts/").replace(".jpg", ".txt")
+
+        with open(prompt_link, "r") as f:
+            line = f.readline()
+            # first_two = list(map(int, line.strip().split()[:2]))
+            parts = line.strip().split()
+            prompts = list(map(int, parts[:2]))
+            object_classes = [int(line.strip().split()[-1])]
+
+        return prompts, object_classes
+
+    def input_to_features(self, x: list, device: str) -> Dict:
         """Computes deep features at the intermediate layer(s) all the way from the input"""
         self.model = self.model.to(device).eval()
+        assert isinstance(x, list) and len(x) == 1
 
         if self.split_id == self.SPLIT_IMGENC:
-            return self._input_to_image_encoder(x)
+            return self._input_to_image_encoder(x, device)
         else:
             self.logger.error(f"Not supported split point {self.split_id}")
 
@@ -129,48 +152,37 @@ def features_to_output(self, x: Dict, device: str):
         self.model = self.model.to(device).eval()
 
         if self.split_id == self.SPLIT_IMGENC:
+            assert "file_name" in x
+
+            prompts, object_classes = self.prompt_inputs(x["file_name"])
+
             return self._image_encoder_to_output(
                 x["data"],
                 x["org_input_size"],
                 x["input_size"],
-                x["prompts"],
-                x["object_classes"],
+                prompts,
+                object_classes,
+                device,
             )
         else:
             self.logger.error(f"Not supported split points {self.split_id}")
 
         raise NotImplementedError
 
     @torch.no_grad()
-    def _input_to_image_encoder(self, x):
+    def _input_to_image_encoder(self, x, device):
         """Computes and return encoded image all the way from the input"""
-        # TODO pre_processing
-        # print("AAAAA _input_to_image_encoder", x ,'\n')
-        # imgs = ImageList(x)
-        imgs = x[0]["image"]
-        feature = {}
-        feature["backbone"] = self.backbone(imgs)
-
-        prompt_link = (
-            x[0]["file_name"].replace("/images/", "/prompts/").replace(".jpg", ".txt")
-        )
-        # print("AAAAA prompt_link", prompt_link)
-
-        with open(prompt_link, "r") as f:
-            line = f.readline()
-            # first_two = list(map(int, line.strip().split()[:2]))
-            parts = line.strip().split()
-            prompts = list(map(int, parts[:2]))
-            object_classes = [int(line.strip().split()[-1])]
+        assert len(x) == 1
 
-        image_sizes = [x[0]["height"], x[0]["width"]]
-        # print("AAAAA image_sizes", image_sizes, int(image_sizes[0]) * int(image_sizes[1])),
+        img = x[0]["image"].to(device)
+        input_size = list(img.size()[2:])
+        feature = {}
+        input_img = self.model.preprocess(img)
+        feature["backbone"] = self.image_encoder(input_img)
 
         return {
             "data": feature,
-            "input_size": image_sizes,
-            "prompts": prompts,
-            "object_classes": object_classes,
+            "input_size": input_size,
         }
 
     @torch.no_grad()
@@ -181,45 +193,6 @@ def get_input_size(self, x):
         image_sizes = [x[0]["height"], x[0]["width"]]
         return image_sizes  # [1024, 1024]
 
-    @torch.no_grad()
-    def get_prompts(self, x):
-        """Computes prompts"""
-        prompt_link = (
-            x[0]["file_name"].replace("/images/", "/prompts/").replace(".jpg", ".txt")
-        )
-        # print("AAAAA prompt_link", prompt_link)
-
-        with open(prompt_link, "r") as f:
-            line = f.readline()
-            # first_two = list(map(int, line.strip().split()[:2]))
-            parts = line.strip().split()
-            prompts = list(map(int, parts[:2]))
-            object_classes = [int(line.strip().split()[-1])]
-
-        image_sizes = [x[0]["height"], x[0]["width"]]
-        # print("AAAAA image_sizes", image_sizes, int(image_sizes[0]) * int(image_sizes[1])),
-
-        return prompts
-
-    @torch.no_grad()
-    def get_object_classes(self, x):
-        """Computes input image size to the network"""
-        prompt_link = (
-            x[0]["file_name"].replace("/images/", "/prompts/").replace(".jpg", ".txt")
-        )
-        # print("AAAAA prompt_link", prompt_link)
-
-        with open(prompt_link, "r") as f:
-            line = f.readline()
-            # first_two = list(map(int, line.strip().split()[:2]))
-            parts = line.strip().split()
-            prompts = list(map(int, parts[:2]))
-            object_classes = [int(line.strip().split()[-1])]
-
-        image_sizes = [x[0]["height"], x[0]["width"]]
-        # print("AAAAA image_sizes", image_sizes, int(image_sizes[0]) * int(image_sizes[1])),
-        return object_classes
-
     @torch.no_grad()
     def _image_encoder_to_output(
         self,
@@ -228,6 +201,7 @@ def _image_encoder_to_output(
         input_img_size: List,
         prompts: List,
         object_classes: List,
+        device,
     ):
         """
         performs downstream task using the encoded image feature
@@ -237,7 +211,7 @@ def _image_encoder_to_output(
 
         input_points = [prompts]  # [[469, 295]] #prompts["points"]
        input_points = np.array(input_points)
-        input_points_ = torch.tensor(input_points)
+        input_points_ = torch.tensor(input_points, device=device)
         input_points_ = input_points_.unsqueeze(-1)
         input_points_ = input_points_.permute(2, 0, 1)
 
@@ -246,7 +220,7 @@
         input_labels_ = input_labels_.unsqueeze(-1)
         input_labels_ = input_labels_.permute(1, 0)
 
-        points = (torch.tensor(input_points_), torch.tensor(input_labels_))
+        points = (input_points_, torch.tensor(input_labels_, device=device))
         prompt_feature = self.prompt_encoder(points=points, boxes=None, masks=None)
         image_pe = self.prompt_encoder.get_dense_pe()
 
@@ -261,7 +235,7 @@
         # post process mask
         masks = F.interpolate(
             low_res_masks,
-            (1024, 1024),
+            (self.image_encoder.img_size, self.image_encoder.img_size),
             mode="bilinear",
             align_corners=False,
         )
@@ -270,7 +244,7 @@
         ]  # [..., : 793, : 1024]
         masks = F.interpolate(
             masks,
-            (input_img_size[0], input_img_size[1]),
+            (org_img_size["height"], org_img_size["width"]),
             mode="bilinear",
             align_corners=False,
         )
@@ -314,14 +288,26 @@
     def forward(self, x):
         """Complete the downstream task with end-to-end manner all the way from the input"""
         # test
-        enc = self._input_to_image_encoder(self, x)
-        dec = self._image_encoder_to_output(enc)
+        enc_res = self._input_to_image_encoder([x], self.device)
 
-        return dec
+        # suppose that the order of keys and values is matched
+        enc_res["data"] = {
+            k: v.to(device=self.device)
+            for k, v in zip(self.split_layer_list, enc_res["data"].values())
+        }
+
+        prompts, object_classes = self.prompt_inputs(x["file_name"])
+
+        dec_res = self._image_encoder_to_output(
+            enc_res["data"],
+            {"height": x["height"], "width": x["width"]},
+            enc_res["input_size"],
+            prompts,
+            object_classes,
+            device=self.device,
+        )
 
-    # @property
-    # def cfg(self):
-    #     return self._cfg
+        return dec_res
 
 
 @register_vision_model("sam_vit_h_4b8939")
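For orientation, a standalone sketch of the imgenc split this wrapper implements: the image encoder run separately from the prompt encoder and mask decoder, written directly against the segment_anything API. The checkpoint path, the dummy image, and the point prompt are placeholders, and the coordinate transform and mask post-processing follow the stock SAM predictor rather than the wrapper's manual F.interpolate calls:

    import numpy as np
    import torch
    from segment_anything import sam_model_registry
    from segment_anything.utils.transforms import ResizeLongestSide

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Placeholder checkpoint path; use the file referenced in cfgs/vision_model/default.yaml.
    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").to(device).eval()

    # Part 1 (input_to_features): image -> image-encoder embedding.
    image = np.zeros((480, 640, 3), dtype=np.uint8)            # dummy RGB image
    transform = ResizeLongestSide(sam.image_encoder.img_size)  # 1024 for ViT-H
    resized = transform.apply_image(image)
    img_t = torch.as_tensor(resized, device=device).permute(2, 0, 1)[None].float()
    input_size = tuple(img_t.shape[-2:])
    with torch.no_grad():
        embedding = sam.image_encoder(sam.preprocess(img_t))

    # Part 2 (features_to_output): embedding + point prompt -> mask.
    point = transform.apply_coords(np.array([[[469, 295]]], dtype=float), image.shape[:2])
    coords = torch.as_tensor(point, dtype=torch.float, device=device)  # (1, 1, 2)
    labels = torch.ones((1, 1), dtype=torch.int, device=device)        # foreground point
    with torch.no_grad():
        sparse, dense = sam.prompt_encoder(points=(coords, labels), boxes=None, masks=None)
        low_res_masks, _ = sam.mask_decoder(
            image_embeddings=embedding,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )
        masks = sam.postprocess_masks(low_res_masks, input_size, image.shape[:2]) > 0.0
    print(masks.shape)  # torch.Size([1, 1, 480, 640])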

compressai_vision/pipelines/base.py

Lines changed: 0 additions & 5 deletions
@@ -374,11 +374,6 @@ def _from_features_to_output(
             for k, v in zip(vision_model.split_layer_list, x["data"].values())
         }
 
-        if "prompts" in x:
-            x["prompts"] = x["prompts"]
-        if "object_classes" in x:
-            x["object_classes"] = x["object_classes"]
-
         results = vision_model.features_to_output(x, self.device_nn_part2)
         if self.configs["nn_task_part2"].dump_results:
             self._create_folder(output_results_dir)
