Skip to content

Commit 794a814

Browse files
author
Songki Choi
authored
Fix division by zero in class incremental learning for classification (#2606)
* Add empty label to reproduce zero-division error Signed-off-by: Songki Choi <[email protected]> * Fix minor typo Signed-off-by: Songki Choi <[email protected]> * Fix empty label 4 -> 3 Signed-off-by: Songki Choi <[email protected]> * Prevent division by zero Signed-off-by: Songki Choi <[email protected]> * Update license Signed-off-by: Songki Choi <[email protected]> * Update CHANGELOG.md Signed-off-by: Songki Choi <[email protected]> * Fix inefficient sampling Signed-off-by: Songki Choi <[email protected]> * Revert indexing Signed-off-by: Songki Choi <[email protected]> * Fix minor typo Signed-off-by: Songki Choi <[email protected]> --------- Signed-off-by: Songki Choi <[email protected]>
1 parent 3ec4c95 commit 794a814

File tree

7 files changed

+25
-16
lines changed

7 files changed

+25
-16
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ All notable changes to this project will be documented in this file.
1515
- Fix IBLoss enablement with DeiT-Tiny when class incremental training (<https://github.com/openvinotoolkit/training_extensions/pull/2595>)
1616
- Fix mmcls bug not wrapping model in DataParallel on CPUs (<https://github.com/openvinotoolkit/training_extensions/pull/2601>)
1717
- Fix h-label loss normalization issue w/ exclusive label group of singe label (<https://github.com/openvinotoolkit/training_extensions/pull/2604>)
18+
- Fix division by zero in class incremental learning for classification (<https://github.com/openvinotoolkit/training_extensions/pull/2606>)
1819

1920
## \[v1.4.3\]
2021

src/otx/algorithms/classification/adapters/mmcls/configurer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def _configure_dataloader(cfg):
574574
CLASS_INC_DATASET = [
575575
"OTXClsDataset",
576576
"OTXMultilabelClsDataset",
577-
"MPAHierarchicalClsDataset",
577+
"OTXHierarchicalClsDataset",
578578
"ClsTVDataset",
579579
]
580580
WEIGHT_MIX_CLASSIFIER = ["SAMImageClassifier"]

src/otx/algorithms/classification/adapters/mmcls/datasets/otx_datasets.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Base Dataset for Classification Task."""
22

3-
# Copyright (C) 2022 Intel Corporation
3+
# Copyright (C) 2022-2023 Intel Corporation
44
# SPDX-License-Identifier: Apache-2.0
55
#
66

@@ -176,7 +176,10 @@ def class_accuracy(self, results, gt_labels):
176176
for i in range(self.num_classes):
177177
cls_pred = pred_label == i
178178
cls_pred = cls_pred[gt_labels == i]
179-
cls_acc = np.sum(cls_pred) / len(cls_pred)
179+
if len(cls_pred) > 0:
180+
cls_acc = np.sum(cls_pred) / len(cls_pred)
181+
else:
182+
cls_acc = 0.0
180183
accracies.append(cls_acc)
181184
return accracies
182185

src/otx/algorithms/classification/adapters/mmcls/models/losses/ib_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Module for defining IB Loss which alleviate effect of imbalanced dataset."""
2-
# Copyright (C) 2022 Intel Corporation
2+
# Copyright (C) 2022-2023 Intel Corporation
33
# SPDX-License-Identifier: Apache-2.0
44
#
55

@@ -48,7 +48,7 @@ def update_weight(self, cls_num_list):
4848
"""Update loss weight per class."""
4949
if len(cls_num_list) == 0:
5050
raise ValueError("Cannot compute the IB loss weight with empty cls_num_list.")
51-
per_cls_weights = 1.0 / np.array(cls_num_list)
51+
per_cls_weights = 1.0 / (np.array(cls_num_list) + self.epsilon)
5252
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
5353
per_cls_weights = torch.FloatTensor(per_cls_weights)
5454
self.weight.data = per_cls_weights.to(device=self.weight.device)

src/otx/algorithms/classification/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ def _generate_training_metrics(self, learning_curves): # pylint: disable=argume
495495
elif self._hierarchical:
496496
metric_key = "val/MHAcc"
497497
else:
498-
metric_key = "val/accuracy_top-1"
498+
metric_key = "val/accuracy (%)"
499499

500500
# Learning curves
501501
best_acc = -1

src/otx/algorithms/common/adapters/torch/dataloaders/samplers/balanced_sampler.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
"""Balanced sampler for imbalanced data."""
2+
# Copyright (C) 2023 Intel Corporation
3+
# SPDX-License-Identifier: Apache-2.0
4+
25
import math
36

47
import numpy as np
@@ -32,24 +35,22 @@ def __init__(self, dataset, batch_size, efficient_mode=True, num_replicas=1, ran
3235
self.dataset = dataset.dataset
3336
else:
3437
self.dataset = dataset
35-
self.img_indices = self.dataset.img_indices
38+
self.img_indices = {k: v for k, v in self.dataset.img_indices.items() if len(v) > 0}
3639
self.num_cls = len(self.img_indices.keys())
3740
self.data_length = len(self.dataset)
3841
self.num_replicas = num_replicas
3942
self.rank = rank
4043
self.drop_last = drop_last
4144

45+
self.num_trials = int(self.data_length / self.num_cls)
4246
if efficient_mode:
4347
# Reduce the # of sampling (sampling data for a single epoch)
44-
self.num_tail = min(len(cls_indices) for cls_indices in self.img_indices.values())
45-
base = 1 - (1 / self.num_tail)
46-
if base == 0:
47-
raise ValueError("Required more than one sample per class")
48-
self.num_trials = int(math.log(0.001, base))
49-
if int(self.data_length / self.num_cls) < self.num_trials:
50-
self.num_trials = int(self.data_length / self.num_cls)
51-
else:
52-
self.num_trials = int(self.data_length / self.num_cls)
48+
num_tail = min(len(cls_indices) for cls_indices in self.img_indices.values())
49+
if num_tail > 1:
50+
base = 1 - (1 / num_tail)
51+
num_reduced_trials = int(math.log(0.001, base))
52+
self.num_trials = min(num_reduced_trials, self.num_trials)
53+
5354
self.num_samples = self._calculate_num_samples()
5455

5556
logger.info(f"This sampler will select balanced samples {self.num_trials} times")
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Ignore everything in this directory
2+
*
3+
# Except this file
4+
!.gitignore

0 commit comments

Comments
 (0)