Skip to content

Commit fcb3293

Browse files
author
Songki Choi
authored
[FIX][RELEASE1.0] Remove cfg dump in ckpt (#1895)
* Remove cfg dump in ckpt Signed-off-by: Songki Choi <[email protected]> * Fix pre-commit Signed-off-by: Songki Choi <[email protected]> --------- Signed-off-by: Songki Choi <[email protected]>
1 parent 3825b1b commit fcb3293

File tree

2 files changed

+15
-36
lines changed

2 files changed

+15
-36
lines changed

otx/algorithms/common/tasks/nncf_base.py

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import io
1919
import json
2020
import os
21-
from collections.abc import Mapping
2221
from copy import deepcopy
2322
from typing import Dict, List, Optional
2423

@@ -338,30 +337,16 @@ def save_model(self, output_model: ModelEntity):
338337
hyperparams_str = ids_to_strings(cfg_helper.convert(self._hyperparams, dict, enum_to_str=True))
339338
labels = {label.name: label.color.rgb_tuple for label in self._labels}
340339

341-
config = deepcopy(self._recipe_cfg)
342-
config.merge_from_dict(self._model_cfg)
343-
344-
def update(d, u): # pylint: disable=invalid-name
345-
for k, v in u.items(): # pylint: disable=invalid-name
346-
if isinstance(v, Mapping):
347-
d[k] = update(d.get(k, {}), v)
348-
else:
349-
d[k] = v
350-
return d
351-
352-
modelinfo = torch.load(self._model_ckpt, map_location=torch.device("cpu"))
353-
modelinfo = update(
354-
dict(model=modelinfo),
355-
{
356-
"meta": {
357-
"nncf_enable_compression": True,
358-
"config": config,
359-
},
360-
"config": hyperparams_str,
361-
"labels": labels,
362-
"VERSION": 1,
340+
model_ckpt = torch.load(self._model_ckpt, map_location=torch.device("cpu"))
341+
modelinfo = {
342+
"model": model_ckpt,
343+
"config": hyperparams_str,
344+
"labels": labels,
345+
"VERSION": 1,
346+
"meta": {
347+
"nncf_enable_compression": True,
363348
},
364-
)
349+
}
365350
self._save_model_post_hook(modelinfo)
366351

367352
torch.save(modelinfo, buffer)

otx/algorithms/detection/tasks/nncf.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from otx.algorithms.common.adapters.mmcv.utils import remove_from_config
2323
from otx.algorithms.common.tasks.nncf_base import NNCFBaseTask
2424
from otx.algorithms.detection.adapters.mmdet.nncf import build_nncf_detector
25+
from otx.algorithms.detection.adapters.mmdet.utils.config_utils import (
26+
should_cluster_anchors,
27+
)
2528
from otx.api.entities.datasets import DatasetEntity
2629
from otx.api.entities.inference_parameters import InferenceParameters
2730
from otx.api.entities.model import ModelEntity
@@ -111,17 +114,8 @@ def _optimize_post_hook(
111114
output_model.performance = performance
112115

113116
def _save_model_post_hook(self, modelinfo):
114-
config = modelinfo["meta"]["config"]
115-
if hasattr(config.model, "bbox_head") and hasattr(config.model.bbox_head, "anchor_generator"):
116-
if getattr(
117-
config.model.bbox_head.anchor_generator,
118-
"reclustering_anchors",
119-
False,
120-
):
121-
generator = config.model.bbox_head.anchor_generator
122-
modelinfo["anchors"] = {
123-
"heights": generator.heights,
124-
"widths": generator.widths,
125-
}
117+
if self._model_cfg is not None and should_cluster_anchors(self._model_cfg):
118+
modelinfo["anchors"] = {}
119+
self._update_anchors(modelinfo["anchors"], self._model_cfg.model.bbox_head.anchor_generator)
126120

127121
modelinfo["confidence_threshold"] = self.confidence_threshold

0 commit comments

Comments
 (0)