Skip to content

Commit e1b5124

Browse files
committed
Minor fixes
1 parent 96d9c72 commit e1b5124

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

scripts/cache_to_ply.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def cache_to_ply(
5252
help="The directory where the cache should be created",
5353
show_default=False,
5454
),
55-
sequence: Optional[str] = typer.Option(
55+
sequence: Optional[List[str]] = typer.Option(
5656
None,
5757
"--sequence",
5858
"-s",
@@ -76,13 +76,14 @@ def cache_to_ply(
7676

7777
# Run
7878
cfg = load_config(config)
79+
sequences = sequence if sequence != None else cfg.training.train + cfg.training.val
7980

8081
data_iterable = DataLoader(
8182
MOS4DDataset(
8283
dataloader=dataloader,
8384
data_dir=data,
8485
config=cfg,
85-
sequences=[sequence],
86+
sequences=sequences,
8687
cache_dir=cache_dir,
8788
),
8889
batch_size=1,

scripts/precache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def precache(
4646
help="The directory where the cache should be created",
4747
show_default=False,
4848
),
49-
sequence: List[str] = typer.Option(
49+
sequence: Optional[List[str]] = typer.Option(
5050
None,
5151
"--sequence",
5252
"-s",
@@ -68,7 +68,7 @@ def precache(
6868
from mos4d.datasets.mos4d_dataset import collate_fn
6969

7070
cfg = load_config(config)
71-
sequences = list(sequence) if sequence != None else cfg.training.train + cfg.training.val
71+
sequences = sequence if sequence != None else cfg.training.train + cfg.training.val
7272

7373
data_iterable = DataLoader(
7474
Dataset(

src/mos4d/datasets/kitti.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,21 @@ def __init__(self, data_dir, sequence: str, *_, **__):
5151
self.label_dir = os.path.join(self.kitti_sequence_dir, "labels/")
5252
self.label_files = sorted(glob.glob(self.label_dir + "*.label"))
5353

54+
# Account for incomplete label files
55+
label_map = {os.path.basename(path): path for path in self.label_files}
56+
if len(self.label_files) != len(self.scan_files):
57+
self.label_files = [
58+
label_map.get(os.path.basename(scan_file).replace(".bin", ".label"), None)
59+
for scan_file in self.scan_files
60+
]
61+
5462
def __getitem__(self, idx):
5563
points, timestamps = self.scans(idx)
5664
labels = (
5765
self.read_labels(self.label_files[idx])
58-
if self.label_files
66+
if (self.label_files and self.label_files[idx] != None)
5967
else np.full((len(points), 1), -1, dtype=np.int32)
60-
)
68+
).reshape(-1)
6169
return points, timestamps, labels
6270

6371
def read_labels(self, filename):

0 commit comments

Comments
 (0)