Skip to content

Commit 81829a3

Browse files
authored
Fix num_trials calculation on dataset length less than num_class (#4014)
Fix balanced sampler
1 parent 7040faf commit 81829a3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/otx/algo/samplers/balanced_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
self.img_indices = {k: torch.tensor(v, dtype=torch.int64) for k, v in ann_stats.items() if len(v) > 0}
6666
self.num_cls = len(self.img_indices.keys())
6767
self.data_length = len(self.dataset)
68-
self.num_trials = int(self.data_length / self.num_cls)
68+
self.num_trials = max(int(self.data_length / self.num_cls), 1)
6969

7070
if efficient_mode:
7171
# Reduce the # of sampling (sampling data for a single epoch)

0 commit comments

Comments
 (0)