Skip to content

Commit 927e20a

Browse files
committed
feat: Update sampler for train dataloader.
1 parent f1c861a commit 927e20a

File tree

3 files changed: 32 additions (+32) and 12 deletions (-12)

vis4d/data/loader.py

Lines changed: 24 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,12 @@
88

99
import numpy as np
1010
import torch
11-
from torch.utils.data import DataLoader, Dataset
11+
from torch.utils.data import (
12+
DataLoader,
13+
Dataset,
14+
RandomSampler,
15+
SequentialSampler,
16+
)
1217
from torch.utils.data.distributed import DistributedSampler, Sampler
1318

1419
from vis4d.common.distributed import get_rank, get_world_size
@@ -167,20 +172,29 @@ def _worker_init_fn(worker_id: int) -> None:
167172
if disable_subprocess_warning and worker_id != 0:
168173
warnings.simplefilter("ignore")
169174

170-
if get_world_size() > 1 and sampler is None:
171-
sampler = DistributedSampler(
172-
dataset, shuffle=shuffle, drop_last=drop_last
173-
)
174-
shuffle = False
175-
drop_last = False
175+
if sampler is None:
176+
if get_world_size() > 1:
177+
sampler = DistributedSampler(
178+
dataset, shuffle=shuffle, drop_last=drop_last
179+
)
180+
shuffle = False
181+
drop_last = False
182+
else:
183+
if shuffle:
184+
sampler = RandomSampler(dataset)
185+
shuffle = False
186+
else:
187+
sampler = SequentialSampler(dataset)
176188

177189
batch_sampler = None
178-
if aspect_ratio_grouping and sampler is not None:
190+
if aspect_ratio_grouping:
179191
batch_sampler = AspectRatioBatchSampler(
180-
sampler, batch_size=samples_per_gpu
192+
sampler, batch_size=samples_per_gpu, drop_last=drop_last
181193
)
182194
samples_per_gpu = 1
183195
shuffle = None
196+
drop_last = False
197+
sampler = None
184198

185199
dataloader = DataLoader(
186200
dataset,
@@ -189,7 +203,7 @@ def _worker_init_fn(worker_id: int) -> None:
189203
collate_fn=(
190204
_collate_fn_multi if dataset.has_reference else _collate_fn_single
191205
),
192-
sampler=sampler if not aspect_ratio_grouping else None,
206+
sampler=sampler,
193207
batch_sampler=batch_sampler,
194208
worker_init_fn=_worker_init_fn,
195209
persistent_workers=workers_per_gpu > 0,

vis4d/data/samplers.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,14 @@ def __init__(
115115

116116
def __iter__(self):
117117
for idx in self.sampler:
118-
data_dict = self.sampler.dataset[idx]
118+
if hasattr(self.sampler, "dataset"):
119+
data_dict = self.sampler.dataset[idx]
120+
elif hasattr(self.sampler, "data_source"):
121+
data_dict = self.sampler.data_source[idx]
122+
else:
123+
raise ValueError(
124+
"sampler should have dataset or data_source attribute"
125+
)
119126
height, width = data_dict[K.input_hw]
120127
bucket_id = 0 if width < height else 1
121128
bucket = self._aspect_ratio_buckets[bucket_id]

vis4d/engine/callbacks/visualizer.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,6 @@ def on_test_batch_end(
137137
trainer.logger.log_image(
138138
key=f"{self.visualizer}/{cur_iter}",
139139
images=[image],
140-
step=trainer.global_step,
141140
)
142141
synchronize()
143142

0 commit comments

Comments
 (0)