
Commit 4a8b079

Add OTX optimize for visual prompting task (#2318)
* Initial commit
* Update block
* (WIP) otx optimize
* Fix
* WIP
* Update configs & exported outputs
* Remove unused modules for torch
* Add unit tests
* pre-commit
* Update CHANGELOG
1 parent 603efca commit 4a8b079

File tree: 16 files changed (+396 / -205 lines)
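This commit wires the visual prompting task into the standard otx optimize flow, which runs post-training quantization (PTQ) on the exported OpenVINO IR. A hedged sketch of a typical invocation; the template path, weight files, and output directories below are placeholders rather than values taken from this commit, and calibration data is drawn from the training subset configured for the task:

    # Export the trained visual prompting model to OpenVINO IR, then quantize it with PTQ.
    otx export template.yaml --load-weights weights.pth --output exported/
    otx optimize template.yaml --load-weights exported/openvino.xml --output optimized/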

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ All notable changes to this project will be documented in this file.
 - Add new visual prompting task: train/eval (https://github.com/openvinotoolkit/training_extensions/pull/2203)
 - Add new visual prompting task: export (https://github.com/openvinotoolkit/training_extensions/pull/2274)
 - Add new visual prompting task: deploy (https://github.com/openvinotoolkit/training_extensions/pull/2311)
+- Add new visual prompting task: optimize (PTQ) (https://github.com/openvinotoolkit/training_extensions/pull/2318)
 - Add new object detector ResNeXt101-ATSS (<https://github.com/openvinotoolkit/training_extensions/pull/2309>)

 ### Enhancements

src/otx/algorithms/visual_prompting/adapters/openvino/model_wrappers/openvino_models.py

Lines changed: 12 additions & 2 deletions
@@ -23,6 +23,7 @@
 from openvino.model_api.models import ImageModel, SegmentationModel
 from openvino.model_api.models.types import NumericalValue, StringValue

+from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.pipelines import ResizeLongestSide
 from otx.api.utils.segmentation_utils import create_hard_prediction_from_soft_prediction


@@ -40,13 +41,20 @@ def parameters(cls) -> Dict[str, Any]:  # noqa: D102
         parameters.update(
             {
                 "resize_type": StringValue(default_value="fit_to_window"),
+                "image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048),
             }
         )
         return parameters

-    def preprocess(self, inputs: np.ndarray) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
+    def preprocess(
+        self, inputs: np.ndarray, extra_processing: bool = False
+    ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
         """Update meta for image encoder."""
         dict_inputs, meta = super().preprocess(inputs)
+        if extra_processing:
+            dict_inputs["images"] = ResizeLongestSide.apply_image(dict_inputs["images"][0], self.image_size).transpose(
+                2, 0, 1
+            )[None]
         meta["resize_type"] = self.resize_type
         return dict_inputs, meta

@@ -63,14 +71,16 @@ def __init__(
         preload: bool = False,
     ):
         super().__init__(model_adapter, configuration, preload)
-        self.output_blob_name = "low_res_masks"

     @classmethod
     def parameters(cls):  # noqa: D102
         parameters = super().parameters()
         parameters.update({"image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048)})
         return parameters

+    def _get_outputs(self):
+        return "low_res_masks"
+
     def preprocess(self, inputs: Dict[str, Any], meta: Dict[str, Any]):
         """Preprocess prompts."""
         processed_prompts = []
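The new extra_processing flag lets callers reuse the wrapper's own preprocessing when they need encoder inputs whose longest side matches image_size, which is what the PTQ calibration path does. A minimal sketch, assuming image_encoder is an already-constructed instance of the image-encoder wrapper above (the variable name is illustrative, not from this hunk):

    import numpy as np

    # Placeholder HxWx3 uint8 frame; in practice this comes from the calibration dataset.
    image = np.zeros((720, 1280, 3), dtype=np.uint8)
    dict_inputs, meta = image_encoder.preprocess(image, extra_processing=True)
    # dict_inputs["images"] is now a 1xCxH'xW' array whose longest side equals image_size (1024 by default).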

src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/config/visual_prompting_config.py

Lines changed: 2 additions & 2 deletions
@@ -97,8 +97,8 @@ def update_visual_prompting_config(
     if groups:
         for group in groups:
             if group in ["learning_parameters", "nncf_optimization", "pot_parameters", "postprocessing"]:
-                if group in ["nncf_optimization", "pot_parameters"]:
-                    # TODO (sungchul): Consider pot_parameters, nncf_optimization, and postprocessing
+                if group in ["nncf_optimization"]:
+                    # TODO (sungchul): Consider nncf_optimization
                     logger.warning(f"{group} will be implemented.")
                     continue
                 update_visual_prompting_config(visual_prompting_config, getattr(otx_config, group))
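After this change learning_parameters, postprocessing, and now pot_parameters are merged into the visual prompting config, while nncf_optimization is still only warned about and skipped. An illustrative call, where hyper_parameters stands for the OTX config object carrying those groups (the name is a placeholder, not from this hunk):

    # learning_parameters / postprocessing / pot_parameters -> merged into visual_prompting_config
    # nncf_optimization                                      -> logged as "will be implemented" and skipped
    update_visual_prompting_config(visual_prompting_config, hyper_parameters)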

src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/pipelines/sam_transforms.py

Lines changed: 5 additions & 54 deletions
@@ -10,7 +10,6 @@
 import numpy as np
 import torch
 from torch import Tensor
-from torch.nn import functional as F
 from torchvision.transforms.functional import resize, to_pil_image  # type: ignore


@@ -36,24 +35,26 @@ def __call__(self, item: Dict[str, Union[List, Tensor]]) -> Dict[str, Union[List
             Dict[str, Union[List, Tensor]]: Dictionary of batch data.
         """
         item["images"] = torch.as_tensor(
-            self.apply_image(item["images"]).transpose((2, 0, 1)), dtype=torch.get_default_dtype()
+            self.apply_image(item["images"], self.target_length).transpose((2, 0, 1)), dtype=torch.get_default_dtype()
         )
         item["gt_masks"] = [torch.as_tensor(gt_mask) for gt_mask in item["gt_masks"]]
         item["bboxes"] = self.apply_boxes(item["bboxes"], item["original_size"])
         if item["points"]:
             item["points"] = self.apply_coords(item["points"], item["original_size"])
         return item

-    def apply_image(self, image: np.ndarray) -> np.ndarray:
+    @classmethod
+    def apply_image(cls, image: np.ndarray, target_length: int) -> np.ndarray:
         """Expects a numpy array with shape HxWxC in uint8 format.

         Args:
             image (np.ndarray): Image array.
+            target_length (int): The length of the longest side of the image.

         Returns:
             np.ndarray: Resized image.
         """
-        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+        target_size = cls.get_preprocess_shape(image.shape[0], image.shape[1], target_length)
         return np.array(resize(to_pil_image(image), target_size))

     def apply_coords(self, coords: np.ndarray, original_size: Union[List[Any], Tensor]) -> np.ndarray:
@@ -88,56 +89,6 @@ def apply_boxes(self, boxes: np.ndarray, original_size: Union[List[Any], Tensor]
         boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
         return boxes.reshape(-1, 4)

-    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
-        """Expects batched images with shape BxCxHxW and float format.
-
-        This transformation may not exactly match apply_image.
-        apply_image is the transformation expected by the model.
-
-        Args:
-            image (torch.Tensor): Image tensor.
-
-        Returns:
-            torch.Tensor: Resized image.
-        """
-        # Expects an image in BCHW format. May not exactly match apply_image.
-        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
-        return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
-
-    def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
-        """Expects a torch tensor with length 2 in the last dimension.
-
-        Requires the original image size in (H, W) format.
-
-        Args:
-            coords (torch.Tensor): Coordinates tensor.
-            original_size (Tuple[int, ...]): Original size of image.
-
-        Returns:
-            torch.Tensor: Resized coordinates.
-        """
-        old_h, old_w = original_size
-        new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
-        coords = deepcopy(coords).to(torch.float)
-        coords[..., 0] = coords[..., 0] * (new_w / old_w)
-        coords[..., 1] = coords[..., 1] * (new_h / old_h)
-        return coords
-
-    def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
-        """Expects a torch tensor with shape Bx4.
-
-        Requires the original image size in (H, W) format.
-
-        Args:
-            boxes (torch.Tensor): Boxes tensor.
-            original_size (Tuple[int, ...]): Original size of image.
-
-        Returns:
-            torch.Tensor: Resized boxes.
-        """
-        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
-        return boxes.reshape(-1, 4)
-
     @staticmethod
     def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
         """Compute the output size given input size and target long side length.

src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py

Lines changed: 14 additions & 12 deletions
@@ -174,9 +174,9 @@ def replace_state_dict_keys(state_dict, revise_keys):
         state_dict = replace_state_dict_keys(state_dict, revise_keys)
         self.load_state_dict(state_dict)

-    #################################################
-    # forward for inference (export/deploy) #
-    #################################################
+    ##########################################################
+    # forward for inference (export/deploy/optimize) #
+    ##########################################################
     @torch.no_grad()
     def forward(
         self,
@@ -185,7 +185,7 @@ def forward(
         point_labels: Tensor,
         mask_input: Tensor,
         has_mask_input: Tensor,
-        orig_size: Tensor,
+        # orig_size: Tensor,
     ):
         """Forward method for SAM inference (export/deploy).

@@ -227,16 +227,18 @@ def forward(
         if self.config.model.return_single_mask:
             masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

-        upscaled_masks = self.mask_postprocessing(masks, orig_size[0])
+        return scores, masks
+        # TODO (sungchul): apply inner postprocessing
+        # upscaled_masks = self.mask_postprocessing(masks, orig_size[0])

-        if self.config.model.return_extra_metrics:
-            stability_scores = self.calculate_stability_score(
-                upscaled_masks, self.config.model.mask_threshold, self.config.model.stability_score_offset
-            )
-            areas = (upscaled_masks > self.config.model.mask_threshold).sum(-1).sum(-1)
-            return upscaled_masks, scores, stability_scores, areas, masks
+        # if self.config.model.return_extra_metrics:
+        #     stability_scores = self.calculate_stability_score(
+        #         upscaled_masks, self.config.model.mask_threshold, self.config.model.stability_score_offset
+        #     )
+        #     areas = (upscaled_masks > self.config.model.mask_threshold).sum(-1).sum(-1)
+        #     return upscaled_masks, scores, stability_scores, areas, masks

-        return upscaled_masks, scores, masks
+        # return upscaled_masks, scores, masks

     def _embed_points(self, point_coords: Tensor, point_labels: Tensor) -> Tensor:
         """Embed sparse input prompts.

src/otx/algorithms/visual_prompting/configs/base/configuration.py

Lines changed: 18 additions & 1 deletion
@@ -17,13 +17,15 @@

 from attr import attrs

-from otx.algorithms.common.configs import BaseConfig
+from otx.algorithms.common.configs import BaseConfig, POTQuantizationPreset
 from otx.api.configuration.elements import (
     ParameterGroup,
     add_parameter_group,
+    boolean_attribute,
     configurable_boolean,
     configurable_float,
     configurable_integer,
+    selectable,
     string_attribute,
 )
 from otx.api.configuration.model_lifecycle import ModelLifecycle
@@ -95,5 +97,20 @@ class __Postprocessing(ParameterGroup):
             affects_outcome_of=ModelLifecycle.INFERENCE,
         )

+    @attrs
+    class __POTParameter(BaseConfig.BasePOTParameter):
+        header = string_attribute("POT Parameters")
+        description = header
+        visible_in_ui = boolean_attribute(False)
+
+        preset = selectable(
+            default_value=POTQuantizationPreset.MIXED,
+            header="Preset",
+            description="Quantization preset that defines quantization scheme",
+            editable=True,
+            visible_in_ui=True,
+        )
+
     learning_parameters = add_parameter_group(__LearningParameters)
     postprocessing = add_parameter_group(__Postprocessing)
+    pot_parameters = add_parameter_group(__POTParameter)
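With the group registered, the PTQ settings travel with the task hyper-parameters like the existing groups do. A small sketch, where hyper_parameters stands for an instance of the config class this file defines (the variable name is a placeholder; stat_subset_size comes from the inherited BasePOTParameter, as the yaml below shows):

    preset = hyper_parameters.pot_parameters.preset                 # POTQuantizationPreset.MIXED by default
    subset_size = hyper_parameters.pot_parameters.stat_subset_size  # inherited from BasePOTParameter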

src/otx/algorithms/visual_prompting/configs/configuration.yaml

Lines changed: 2 additions & 2 deletions
@@ -148,7 +148,7 @@ pot_parameters:
     affects_outcome_of: NONE
     auto_hpo_state: not_possible
     auto_hpo_value: null
-    default_value: Performance
+    default_value: Mixed
     description: Quantization preset that defines quantization scheme
     editable: true
     enum_name: POTQuantizationPreset
@@ -162,7 +162,7 @@ pot_parameters:
       operator: AND
       rules: []
       type: UI_RULES
-    value: Performance
+    value: Mixed
     visible_in_ui: true
     warning: null
   stat_subset_size:
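Under the hood, the optimize (PTQ) path hands the exported IR and a calibration subset to NNCF post-training quantization using this preset. A rough sketch of that underlying call, not the task's exact code; the model path, calibration items, and transform function are placeholders:

    import nncf
    import openvino.runtime as ov

    ov_model = ov.Core().read_model("visual_prompting_image_encoder.xml")  # placeholder IR path

    def transform_fn(item):
        # Placeholder: the real transform builds the encoder's input dict from a dataset item.
        return item

    calibration_items = []  # placeholder: a few hundred training samples are typical
    quantized_model = nncf.quantize(
        ov_model,
        nncf.Dataset(calibration_items, transform_fn),
        preset=nncf.QuantizationPreset.MIXED,  # matches the Mixed default above
        subset_size=300,                       # corresponds to stat_subset_size
    )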

src/otx/algorithms/visual_prompting/configs/sam_vit_b/config.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 dataset:
   task: visual_prompting
-  train_batch_size: 2
+  train_batch_size: 4
   val_batch_size: 1
   test_batch_size: 1
   num_workers: 4
@@ -35,7 +35,7 @@ model:

 optimizer:
   name: Adam
-  lr: 0.0001
+  lr: 0.000001

 callback:
   checkpoint: # arguments for ModelCheckpoint

src/otx/algorithms/visual_prompting/configs/sam_vit_b/configuration.yaml

Lines changed: 2 additions & 58 deletions
@@ -85,70 +85,14 @@ learning_parameters:
     visible_in_ui: true
     warning: null
     auto_hpo_state: NOT_POSSIBLE
-nncf_optimization:
-  description: Optimization by NNCF
-  enable_pruning:
-    affects_outcome_of: NONE
-    auto_hpo_state: not_possible
-    auto_hpo_value: null
-    default_value: false
-    description: Enable filter pruning algorithm
-    editable: true
-    header: Enable filter pruning algorithm
-    type: BOOLEAN
-    ui_rules:
-      action: DISABLE_EDITING
-      operator: AND
-      rules: []
-      type: UI_RULES
-    value: false
-    visible_in_ui: true
-    warning: null
-  enable_quantization:
-    affects_outcome_of: NONE
-    auto_hpo_state: not_possible
-    auto_hpo_value: null
-    default_value: true
-    description: Enable quantization algorithm
-    editable: true
-    header: Enable quantization algorithm
-    type: BOOLEAN
-    ui_rules:
-      action: DISABLE_EDITING
-      operator: AND
-      rules: []
-      type: UI_RULES
-    value: true
-    visible_in_ui: true
-    warning: null
-  header: Optimization by NNCF
-  pruning_supported:
-    affects_outcome_of: TRAINING
-    auto_hpo_state: not_possible
-    auto_hpo_value: null
-    default_value: false
-    description: Whether filter pruning is supported
-    editable: false
-    header: Whether filter pruning is supported
-    type: BOOLEAN
-    ui_rules:
-      action: DISABLE_EDITING
-      operator: AND
-      rules: []
-      type: UI_RULES
-    value: false
-    visible_in_ui: false
-    warning: null
-  type: PARAMETER_GROUP
-  visible_in_ui: true
 pot_parameters:
   description: POT Parameters
   header: POT Parameters
   preset:
     affects_outcome_of: NONE
     auto_hpo_state: not_possible
     auto_hpo_value: null
-    default_value: Performance
+    default_value: Mixed
     description: Quantization preset that defines quantization scheme
     editable: true
     enum_name: POTQuantizationPreset
@@ -162,7 +106,7 @@ pot_parameters:
       operator: AND
       rules: []
      type: UI_RULES
-    value: Performance
+    value: Mixed
     visible_in_ui: true
     warning: null
   stat_subset_size:

src/otx/algorithms/visual_prompting/tasks/inference.py

Lines changed: 1 addition & 2 deletions
@@ -281,9 +281,8 @@ def _export_to_onnx(self, onnx_path: Dict[str, str]):
             "point_labels": torch.randint(low=0, high=4, size=(1, 2), dtype=torch.float),
             "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
             "has_mask_input": torch.tensor([[1]], dtype=torch.float),
-            "orig_size": torch.tensor([[height, width]], dtype=torch.float),
         }
-        output_names = ["masks", "iou_predictions", "low_res_masks"]
+        output_names = ["iou_predictions", "low_res_masks"]
        model_to_export = self.model

        with warnings.catch_warnings():
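With orig_size dropped from the dummy inputs and masks dropped from output_names, the exported decoder ONNX graph ends at the two remaining tensors. A quick way to confirm this on the exported file (the path is a placeholder):

    import onnx

    decoder = onnx.load("visual_prompting_decoder.onnx")      # placeholder path
    print([output.name for output in decoder.graph.output])   # expected: ['iou_predictions', 'low_res_masks']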
