Commit bad2081

OTX export for visual prompting task (#2274)
* Enable to export models
  TODOs: check weights in eval, apply transforms
* Update get_visual_promtping_config
* Fix
* Distinguish task mode when setting config
* Fix
* Combine stage's transform
* Add sam ir paths
* Insert transform into dataset
* Add openvino entrypoint
* (WIP) OpenVINOTask
* Fix indent
* Add set_metadata
* Update load_model
* Fix typo
* Add normalize factors in config & add get_transform function
* Enable to load pretrained weights using "pretrained"
* Fix
* Pull preprocessing functions out of dataset class
* Edit updating metas
* Remove legacy
* Update to always train with original size
* Move configuration.yaml
* Update to match openvino input layout
* Add model wrapper & update configuration
* WIP
* Remove adapter wrapper
* Set two model wrappers
* Set get_prompt as staticmethod
* Set `VisualPromptingToAnnotationConverter`
* Enable sync process
* Remove `sam` in module name
* Fix ov path
* Update file names
* Complete ov inference with sync process
* Fix
* Set sync inference as default
  - Implement async inference, but there is a performance degradation issue and the reason still needs to be checked
* Add warning about async
* Move decoder preprocess
* Remove async parts
* Update init
* Fix
* Fix converter
* Docstring
* Update transform tests
* Fix import path
* Update test_segment_anything
* Update path name
* Add unit test for openvino task
* Delete visual prompting converter in `create_converter`
* Add unit test for model_wrappers
* Avoid for export to omit original config.yaml
* Add integration tests
* Make pre-commit happy
* Rebase
* Update CHANGELOG
* Fix
* Fix unused logging
1 parent 50e4f35 commit bad2081

45 files changed (+2286, −446 lines)


CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ All notable changes to this project will be documented in this file.
 - Add per-class XAI saliency maps for Mask R-CNN model (https://github.com/openvinotoolkit/training_extensions/pull/2227)
 - Add new object detector Deformable DETR (<https://github.com/openvinotoolkit/training_extensions/pull/2249>)
 - Add new object detecotr DINO(<https://github.com/openvinotoolkit/training_extensions/pull/2266>)
-- Add new visual prompting task (https://github.com/openvinotoolkit/training_extensions/pull/2203)
+- Add new visual prompting task (https://github.com/openvinotoolkit/training_extensions/pull/2203), (https://github.com/openvinotoolkit/training_extensions/pull/2274)

 ### Enhancements

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+"""OpenVINO modules for visual prompting task."""
+
+# Copyright (C) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+"""Model Wrapper of OTX Visual Prompting."""
+
+# Copyright (C) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
+
+from typing import Any, Dict, Tuple
+
+import cv2
+import numpy as np
+from openvino.model_zoo.model_api.models import ImageModel
+from openvino.model_zoo.model_api.models.types import NumericalValue
+
+from otx.algorithms.segmentation.adapters.openvino.model_wrappers.blur import (
+    BlurSegmentation,
+)
+from otx.api.utils.segmentation_utils import create_hard_prediction_from_soft_prediction
+
+
+class ImageEncoder(ImageModel):
+    """Image encoder class for visual prompting of openvino model wrapper."""
+
+    __model__ = "image_encoder"
+
+    @classmethod
+    def parameters(cls) -> Dict[str, Any]:  # noqa: D102
+        parameters = super().parameters()
+        parameters["resize_type"].default_value = "fit_to_window"
+        parameters["mean_values"].default_value = [123.675, 116.28, 103.53]
+        parameters["scale_values"].default_value = [58.395, 57.12, 57.375]
+        return parameters
+
+
+class Decoder(BlurSegmentation):
+    """Decoder class for visual prompting of openvino model wrapper.
+
+    TODO (sungchul): change parent class
+    """
+
+    __model__ = "decoder"
+
+    def preprocess(self, bbox: np.ndarray, original_size: Tuple[int]) -> Dict[str, Any]:
+        """Ready decoder inputs."""
+        point_coords = bbox.reshape((-1, 2, 2))
+        point_labels = np.array([2, 3], dtype=np.float32).reshape((-1, 2))
+        inputs_decoder = {
+            "point_coords": point_coords,
+            "point_labels": point_labels,
+            # TODO (sungchul): how to generate mask_input and has_mask_input
+            "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
+            "has_mask_input": np.zeros((1, 1), dtype=np.float32),
+            "orig_size": np.array(original_size, dtype=np.float32).reshape((-1, 2)),
+        }
+        return inputs_decoder
+
+    @classmethod
+    def parameters(cls):  # noqa: D102
+        parameters = super().parameters()
+        parameters.update({"image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048)})
+        return parameters
+
+    def _get_inputs(self):
+        """Get input layer name and shape."""
+        image_blob_names = [name for name in self.inputs.keys()]
+        image_info_blob_names = []
+        return image_blob_names, image_info_blob_names
+
+    def _get_outputs(self):
+        """Get output layer name and shape."""
+        layer_name = "low_res_masks"
+        layer_shape = self.outputs[layer_name].shape
+
+        if len(layer_shape) == 3:
+            self.out_channels = 0
+        elif len(layer_shape) == 4:
+            self.out_channels = layer_shape[1]
+        else:
+            raise Exception(f"Unexpected output layer shape {layer_shape}. Only 4D and 3D output layers are supported")
+
+        return layer_name
+
+    def postprocess(self, outputs: Dict[str, np.ndarray], meta: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
+        """Postprocess to convert soft prediction to hard prediction.
+
+        Args:
+            outputs (Dict[str, np.ndarray]): The output of the model.
+            meta (Dict[str, Any]): Contain label and original size.
+
+        Returns:
+            hard_prediction (np.ndarray): The hard prediction.
+            soft_prediction (np.ndarray): Resized, cropped, and normalized soft prediction.
+        """
+
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))
+
+        soft_prediction = outputs[self.output_blob_name].squeeze()
+        soft_prediction = self.resize_and_crop(soft_prediction, meta["original_size"])
+        soft_prediction = sigmoid(soft_prediction)
+        meta["soft_prediction"] = soft_prediction
+
+        hard_prediction = create_hard_prediction_from_soft_prediction(
+            soft_prediction=soft_prediction,
+            soft_threshold=self.soft_threshold,
+            blur_strength=self.blur_strength,
+        )
+
+        probability = max(min(float(outputs["iou_predictions"]), 1.0), 0.0)
+        meta["label"].probability = probability
+
+        return hard_prediction, soft_prediction
+
+    def resize_and_crop(self, soft_prediction: np.ndarray, original_size: np.ndarray) -> np.ndarray:
+        """Resize and crop soft prediction.
+
+        Args:
+            soft_prediction (np.ndarray): Predicted soft prediction with HxW shape.
+            original_size (np.ndarray): The original image size.
+
+        Returns:
+            final_soft_prediction (np.ndarray): Resized and cropped soft prediction for the original image.
+        """
+        resized_soft_prediction = cv2.resize(
+            soft_prediction, (self.image_size, self.image_size), 0, 0, interpolation=cv2.INTER_LINEAR
+        )
+
+        prepadded_size = self.resize_longest_image_size(original_size, self.image_size).astype(np.int64)
+        resized_cropped_soft_prediction = resized_soft_prediction[..., : prepadded_size[0], : prepadded_size[1]]
+
+        original_size = original_size.astype(np.int64)
+        h, w = original_size[0], original_size[1]
+        final_soft_prediction = cv2.resize(
+            resized_cropped_soft_prediction, (w, h), 0, 0, interpolation=cv2.INTER_LINEAR
+        )
+        return final_soft_prediction
+
+    def resize_longest_image_size(self, original_size: np.ndarray, longest_side: int) -> np.ndarray:
+        """Resizes the longest side of the image to the given size.
+
+        Args:
+            original_size (np.ndarray): The original image size with shape Bx2.
+            longest_side (int): The size of the longest side.
+
+        Returns:
+            transformed_size (np.ndarray): The transformed image size with shape Bx2.
+        """
+        original_size = original_size.astype(np.float32)
+        scale = longest_side / np.max(original_size)
+        transformed_size = scale * original_size
+        transformed_size = np.floor(transformed_size + 0.5).astype(np.int64)
+        return transformed_size
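Taken together, the two wrappers split SAM-style inference into one heavy image-encoding pass and a cheap per-prompt decoding pass. Below is a minimal sketch of how they might be wired up through the OMZ ModelAPI adapter; the IR paths, the "image_embeddings" and "iou_predictions" tensor names, and the _Label stand-in are illustrative assumptions, not part of this commit:

    import numpy as np
    from openvino.model_zoo.model_api.adapters import OpenvinoAdapter, create_core

    # ImageEncoder and Decoder are the wrappers added above; the IR paths are hypothetical.
    core = create_core()
    encoder = ImageEncoder(OpenvinoAdapter(core, "visual_prompting_image_encoder.xml"), preload=True)
    decoder = Decoder(OpenvinoAdapter(core, "visual_prompting_decoder.xml"), preload=True)

    class _Label:  # stand-in for an OTX ScoredLabel: postprocess() writes `.probability` into it
        probability = 0.0

    image = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a real frame
    dict_inputs, meta = encoder.preprocess(image)    # fit_to_window resize plus mean/scale normalization
    image_embeddings = encoder.infer_sync(dict_inputs)

    # One box prompt as (x1, y1, x2, y2); preprocess() reshapes it into two corner points.
    inputs_decoder = decoder.preprocess(bbox=np.array([100, 100, 300, 250]), original_size=(480, 640))
    inputs_decoder.update(image_embeddings)  # assumes the encoder output is named "image_embeddings"
    outputs = decoder.infer_sync(inputs_decoder)

    hard_mask, soft_mask = decoder.postprocess(
        outputs, meta={"label": _Label(), "original_size": np.array([480, 640])}
    )

Encoding once and reusing the embeddings across many prompts is the apparent motivation for keeping the two wrappers separate.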

src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/callbacks/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -14,6 +14,4 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.

-from .inference import InferenceCallback
-
-__all__ = ["InferenceCallback"]
+from .inference import InferenceCallback  # noqa: F401

src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/config/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.

-from .visual_prompting_config import get_visual_promtping_config, update_visual_prompting_config
-
-__all__ = ["get_visual_promtping_config", "update_visual_prompting_config"]
+from .visual_prompting_config import (
+    get_visual_promtping_config,  # noqa: F401
+    update_visual_prompting_config,  # noqa: F401
+)

src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/config/visual_prompting_config.py

Lines changed: 21 additions & 20 deletions
@@ -29,52 +29,53 @@
 def get_visual_promtping_config(
     task_name: str,
     otx_config: ConfigurableParameters,
-    output_path: Path,
+    config_dir: str,
+    mode: str = "train",
     model_checkpoint: Optional[str] = None,
     resume_from_checkpoint: Optional[str] = None,
-    config_filename: str = "config",
-    config_file_extension: str = "yaml",
 ) -> Union[DictConfig, ListConfig]:
     """Get visual prompting configuration.

     Create a visual prompting config object that matches the values specified in the
     OTX config.

     Args:
-        task_name (str): Task name to load configuration from Visual Prompting.
-        otx_config (ConfigurableParameters): OTX config object parsed from configuration.yaml file.
-        output_path (Path): Path to save the configuration file.
+        task_name (str): Task name to load configuration from visual prompting.
+        otx_config (ConfigurableParameters): OTX config object parsed from `configuration.yaml` file.
+        config_dir (str): Path to load raw `config.yaml` or save updated `config.yaml`.
+        mode (str): Mode to run visual prompting task. Default: "train".
         model_checkpoint (Optional[str]): Path to the checkpoint to load the model weights.
         resume_from_checkpoint (Optional[str]): Path to the checkpoint to resume training.
-        config_filename (str): Name of the configuration file, defaults to "config".
-        config_file_extension (str): Extension of the configuration file, defaults to "yaml".

     Returns:
         Union[DictConfig, ListConfig]: Visual prompting config object for the specified model type
         with overwritten default values.
     """
-    if os.path.isfile(os.path.join(output_path, "config.yaml")):
+    if os.path.isfile(os.path.join(config_dir, "config.yaml")):
         # If there is already a config.yaml file in the output path, load it
-        config_path = os.path.join(output_path, "config.yaml")
+        config_path = os.path.join(config_dir, "config.yaml")
     else:
         # Load the default config.yaml file
+        logger.info("[*] Load default config.yaml.")
         config_path = f"src/otx/algorithms/visual_prompting/configs/{task_name.lower()}/config.yaml"

     config = OmegaConf.load(config_path)
     logger.info(f"[*] Load configuration file at {config_path}")

     update_visual_prompting_config(config, otx_config)

-    # update model_checkpoint
-    if model_checkpoint:
-        config.model.checkpoint = model_checkpoint
+    if mode == "train":
+        # update model_checkpoint
+        if model_checkpoint:
+            config.model.checkpoint = model_checkpoint
+
+        # update resume_from_checkpoint
+        config.trainer.resume_from_checkpoint = resume_from_checkpoint

-    # update resume_from_checkpoint
-    config.trainer.resume_from_checkpoint = resume_from_checkpoint
+        save_path = Path(os.path.join(config_dir, "config.yaml"))
+        save_path.write_text(OmegaConf.to_yaml(config))
+        logger.info(f"[*] Save updated configuration file at {str(save_path)}")

-    save_path = Path(os.path.join(output_path, f"{config_filename}.{config_file_extension}"))
-    save_path.write_text(OmegaConf.to_yaml(config))
-    logger.info(f"[*] Save configuration file at {str(save_path)}")
     return config


@@ -95,9 +96,9 @@ def update_visual_prompting_config(
     groups = getattr(otx_config, "groups", None)
     if groups:
         for group in groups:
-            if group in ["learning_parameters", "nncf_optimization", "pot_parameters"]:
+            if group in ["learning_parameters", "nncf_optimization", "pot_parameters", "postprocessing"]:
                 if group in ["nncf_optimization", "pot_parameters"]:
-                    # TODO (sungchul): Consider pot_parameters and nncf_optimization
+                    # TODO (sungchul): Consider pot_parameters, nncf_optimization, and postprocessing
                     logger.warning(f"{group} will be implemented.")
                     continue
                 update_visual_prompting_config(visual_prompting_config, getattr(otx_config, group))
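With the reworked signature, the configuration file name is fixed to config.yaml, loading falls back to the packaged default when config_dir has no copy yet, and checkpoint wiring plus the on-disk save now happen only in "train" mode. A minimal call sketch follows; the task name, output directory, and the bare ConfigurableParameters stand-in are illustrative assumptions, and the default-config lookup is relative to the repository root:

    from pathlib import Path

    from otx.api.configuration import ConfigurableParameters
    from otx.algorithms.visual_prompting.adapters.pytorch_lightning.config import (
        get_visual_promtping_config,  # sic: the typo is part of the exported name
    )

    # Bare stand-in; a real call passes the task environment's hyper-parameters.
    otx_config = ConfigurableParameters(header="Visual prompting parameters")

    config = get_visual_promtping_config(
        task_name="SAM",         # lower-cased to locate configs/sam/config.yaml (illustrative)
        otx_config=otx_config,
        config_dir="./outputs",  # an existing ./outputs/config.yaml wins over the default
        mode="train",            # any other mode skips checkpoint wiring and the save
    )
    assert (Path("./outputs") / "config.yaml").exists()

Saving only in the train branch appears to be what the "Avoid for export to omit original config.yaml" commit bullet refers to: export and evaluation reuse the config.yaml produced during training instead of overwriting it.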

src/otx/algorithms/visual_prompting/adapters/pytorch_lightning/datasets/__init__.py

Lines changed: 6 additions & 4 deletions
@@ -3,7 +3,9 @@
 # Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-from .dataset import OTXVisualPromptingDataModule
-from .pipelines import MultipleInputsCompose, Pad, ResizeLongestSide
-
-__all__ = ["OTXVisualPromptingDataModule", "ResizeLongestSide", "MultipleInputsCompose", "Pad"]
+from .dataset import (
+    OTXVisualPromptingDataModule,  # noqa: F401
+    OTXVisualPromptingDataset,  # noqa: F401
+    get_transform,  # noqa: F401
+)
+from .pipelines import MultipleInputsCompose, Pad, ResizeLongestSide  # noqa: F401
