Skip to content

Commit de75276

Browse files
authored
Fix export issue in Geti (#2316)
1 parent 7027132 commit de75276

File tree

5 files changed

+40
-5
lines changed

5 files changed

+40
-5
lines changed

src/otx/algorithms/classification/adapters/mmcls/task.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,18 @@ def hook(module, inp, outp): # pylint: disable=unused-argument
549549
return eval_predictions, saliency_maps
550550

551551
def _export_model(self, precision: ModelPrecision, export_format: ExportType, dump_features: bool):
552-
self._data_cfg = None
552+
self._data_cfg = ConfigDict(
553+
data=ConfigDict(
554+
train=ConfigDict(
555+
otx_dataset=None,
556+
labels=self._labels,
557+
),
558+
test=ConfigDict(
559+
otx_dataset=None,
560+
labels=self._labels,
561+
),
562+
)
563+
)
553564
self._init_task(export=True)
554565

555566
cfg = self.configure(False, "test", None)

src/otx/algorithms/detection/adapters/mmdet/task.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,18 @@ def _export_model(
491491
dump_features: bool,
492492
):
493493
"""Main export function of OTX MMDetection Task."""
494-
self._data_cfg = None
494+
self._data_cfg = ConfigDict(
495+
data=ConfigDict(
496+
train=ConfigDict(
497+
otx_dataset=None,
498+
labels=self._labels,
499+
),
500+
test=ConfigDict(
501+
otx_dataset=None,
502+
labels=self._labels,
503+
),
504+
)
505+
)
495506
self._init_task(export=True)
496507

497508
cfg = self.configure(False, "test", None)

src/otx/algorithms/segmentation/adapters/mmseg/task.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,18 @@ def _export_model(
423423
):
424424
"""Export function of OTX Segmentation Task."""
425425
# copied from OTX inference_task.py
426-
self._data_cfg = None
426+
self._data_cfg = ConfigDict(
427+
data=ConfigDict(
428+
train=ConfigDict(
429+
otx_dataset=None,
430+
labels=self._labels,
431+
),
432+
test=ConfigDict(
433+
otx_dataset=None,
434+
labels=self._labels,
435+
),
436+
)
437+
)
427438
self._init_task(export=True)
428439

429440
cfg = self.configure(False, "test", None)

tests/unit/algorithms/classification/adapters/mmcls/test_task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ class MockExporter:
112112
def __init__(self, task):
113113
self._output_path = task._output_path
114114

115-
def run(self, *args, **kwargs):
115+
def run(self, cfg, *args, **kwargs):
116+
assert cfg.model.head.num_classes == 2
116117
with open(os.path.join(self._output_path, "openvino.bin"), "wb") as f:
117118
f.write(np.ndarray([0]))
118119
with open(os.path.join(self._output_path, "openvino.xml"), "wb") as f:

tests/unit/algorithms/detection/adapters/mmdet/test_task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ class MockExporter:
114114
def __init__(self, task):
115115
self._output_path = task._output_path
116116

117-
def run(self, *args, **kwargs):
117+
def run(self, cfg, *args, **kwargs):
118+
cfg.model.bbox_head.num_classes == 3
118119
with open(os.path.join(self._output_path, "openvino.bin"), "wb") as f:
119120
f.write(np.ndarray([0]))
120121
with open(os.path.join(self._output_path, "openvino.xml"), "wb") as f:

0 commit comments

Comments
 (0)