Skip to content

Commit b3fef7a

Browse files
authored
Update export and nncf hyperparameters (#2306)
1 parent 4fc2e8d commit b3fef7a

File tree

3 files changed

+3
-2
lines changed

3 files changed

+3
-2
lines changed

otx/algorithms/detection/adapters/mmdet/nncf/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,6 @@ def _optimize_post_hook(
113113
def _save_model_post_hook(self, modelinfo):
114114
if self._recipe_cfg is not None and should_cluster_anchors(self._recipe_cfg):
115115
modelinfo["anchors"] = {}
116-
self._update_anchors(modelinfo["anchors"], self._recipe_cfg.model.bbox_head.anchor_generator)
116+
self._update_anchors(modelinfo["anchors"], self.config.model.bbox_head.anchor_generator)
117117

118118
modelinfo["confidence_threshold"] = self.confidence_threshold

otx/algorithms/detection/adapters/mmdet/utils/config_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def patch_input_shape(cfg: ConfigDict, deploy_cfg: ConfigDict):
458458
w, h = size
459459
logger.info(f"Patching OpenVINO IR input shape: {size}")
460460
deploy_cfg.ir_config.input_shape = (w, h)
461-
deploy_cfg.backend_config.model_inputs = [ConfigDict(opt_shapes=ConfigDict(input=[1, 3, h, w]))]
461+
deploy_cfg.backend_config.model_inputs = [ConfigDict(opt_shapes=ConfigDict(input=[-1, 3, h, w]))]
462462

463463

464464
def patch_ir_scale_factor(deploy_cfg: ConfigDict, hyper_parameters: DetectionConfig):

tests/unit/algorithms/detection/adapters/mmdet/nncf/test_task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def test_save_model(self, mocker):
4949
}
5050
}
5151
)
52+
self.det_nncf_task.config = self.det_nncf_task._recipe_cfg
5253
self.det_nncf_task.save_model(self.model)
5354

5455
assert self.model.get_data("weights.pth")

0 commit comments

Comments
 (0)