11# Copyright (c) Seeed Technology Co.,Ltd. All rights reserved.
2+ import random
23from typing import Iterator , List , Optional , Union
34
4- import torch
5- import random
65import numpy as np
7- from torch . utils . data import ConcatDataset
6+ import torch
87from mmengine .dataset import DefaultSampler
8+ from torch .utils .data import ConcatDataset
99
10- from sscma .registry import DATA_SANPLERS
10+ from sscma .registry import DATA_SAMPLERS
1111
1212
13- @DATA_SANPLERS .register_module ()
13+ @DATA_SAMPLERS .register_module ()
1414class SemiSampler (DefaultSampler ):
15- """
16- Sampler for scaled sampling of semi-supervised data
15+ """Sampler for scaled sampling of semi-supervised data.
1716
1817 Params:
1918 dataset (torch::ConcatDataset): Multiple merged datasets
@@ -38,7 +37,7 @@ def __init__(
3837 ) -> None :
3938 assert len (sample_ratio ) == len (
4039 dataset .cumulative_sizes
41- ), " Sampling rate length must be equal to the number of datasets."
40+ ), ' Sampling rate length must be equal to the number of datasets.'
4241
4342 super (SemiSampler , self ).__init__ (dataset , shuffle = shuffle , seed = seed , round_up = round_up )
4443 if seed is not None :
@@ -65,7 +64,7 @@ def __init__(
6564 self .computer_epoch ()
6665
6766 def __iter__ (self ) -> Iterator [int ]:
68- indexs = []
67+ indexes = []
6968 num1 = 0
7069 num2 = 0
7170 data1_len = len (self .data1 )
@@ -77,21 +76,21 @@ def __iter__(self) -> Iterator[int]:
7776 for _ in range (self .total_epoch ):
7877 if self .all_data :
7978 for _ in range (self .sample_size [0 ]):
80- indexs .append (self .data1 [num1 % data1_len ])
79+ indexes .append (self .data1 [num1 % data1_len ])
8180 num1 += 1
8281 for _ in range (self .sample_size [1 ]):
83- indexs .append (self .data2 [num2 % data2_len ])
82+ indexes .append (self .data2 [num2 % data2_len ])
8483 num2 += 1
8584 elif self .only_label :
8685 for _ in range (self .sample_size [0 ] + self .sample_size [1 ]):
87- indexs .append (self .data1 [num1 % data1_len ])
86+ indexes .append (self .data1 [num1 % data1_len ])
8887 num1 += 1
8988 else :
9089 for _ in range (self .sample_size [0 ] + self .sample_size [1 ]):
91- indexs .append (self .data2 [num2 % data2_len ])
90+ indexes .append (self .data2 [num2 % data2_len ])
9291 num2 += 1
9392
94- return iter (indexs )
93+ return iter (indexes )
9594
9695 def __len__ (self ) -> int :
9796 return self .total_epoch * self ._batch_size
0 commit comments