Skip to content

Commit 1e7e902

Browse files
feat(cls/shufflenet) use native infinite sampler (#9)
1 parent b5da0a1 commit 1e7e902

File tree

1 file changed

+2
-13
lines changed
  • official/vision/classification/shufflenet

1 file changed

+2
-13
lines changed

official/vision/classification/shufflenet/train.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,6 @@ def get_parameters(model):
112112
return groups
113113

114114

115-
def infinite_iter(loader):
116-
iterator = iter(loader)
117-
while True:
118-
try:
119-
yield next(iterator)
120-
except StopIteration:
121-
iterator = iter(loader)
122-
yield next(iterator)
123-
124-
125115
def worker(rank, world_size, args):
126116
if world_size > 1:
127117
# Initialize distributed process group
@@ -174,9 +164,9 @@ def valid_func(image, label):
174164
# Build train and valid datasets
175165
logger.info("preparing dataset..")
176166
train_dataset = data.dataset.ImageNet(args.data, train=True)
177-
train_sampler = data.RandomSampler(
167+
train_sampler = data.Infinite(data.RandomSampler(
178168
train_dataset, batch_size=args.batch_size, drop_last=True
179-
)
169+
))
180170
train_queue = data.DataLoader(
181171
train_dataset,
182172
sampler=train_sampler,
@@ -193,7 +183,6 @@ def valid_func(image, label):
193183
),
194184
num_workers=args.workers,
195185
)
196-
train_queue = infinite_iter(train_queue)
197186

198187
valid_dataset = data.dataset.ImageNet(args.data, train=False)
199188
valid_sampler = data.SequentialSampler(

0 commit comments

Comments
 (0)