Skip to content

Commit 75c5c5f

Browse files
authored
Fix - Fix imports for python 3.10 (#1)
**💬 Feature description:** Similar to JDAI-CV#713 **🧪 How to test:** Tested in https://github.gsissc.myatos.net/GLB-BDS-AILAB-CV/innov-tools-tracklet_annotations/pull/30
1 parent c9bc3ce commit 75c5c5f

File tree

2 files changed

+37
-18
lines changed

2 files changed

+37
-18
lines changed

fastreid/data/build.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,16 @@
88
import os
99

1010
import torch
11-
TORCH_MAJOR = int(torch.__version__.split('.')[0])
12-
TORCH_MINOR = int(torch.__version__.split('.')[1])
11+
12+
TORCH_MAJOR = int(torch.__version__.split(".")[0])
13+
TORCH_MINOR = int(torch.__version__.split(".")[1])
1314

1415
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
1516
from torch._six import string_classes
1617
else:
1718
string_classes = str
1819

19-
from collections import Mapping
20+
from collections.abc import Mapping
2021

2122
from fastreid.config import configurable
2223
from fastreid.utils import comm
@@ -26,15 +27,14 @@
2627
from .datasets import DATASET_REGISTRY
2728
from .transforms import build_transforms
2829

29-
__all__ = [
30-
"build_reid_train_loader",
31-
"build_reid_test_loader"
32-
]
30+
__all__ = ["build_reid_train_loader", "build_reid_test_loader"]
3331

3432
_root = os.getenv("FASTREID_DATASETS", "datasets")
3533

3634

37-
def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=None, **kwargs):
35+
def _train_loader_from_config(
36+
cfg, *, train_set=None, transforms=None, sampler=None, **kwargs
37+
):
3838
if transforms is None:
3939
transforms = build_transforms(cfg, is_train=True)
4040

@@ -58,12 +58,18 @@ def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=N
5858
if sampler_name == "TrainingSampler":
5959
sampler = samplers.TrainingSampler(len(train_set))
6060
elif sampler_name == "NaiveIdentitySampler":
61-
sampler = samplers.NaiveIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
61+
sampler = samplers.NaiveIdentitySampler(
62+
train_set.img_items, mini_batch_size, num_instance
63+
)
6264
elif sampler_name == "BalancedIdentitySampler":
63-
sampler = samplers.BalancedIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
65+
sampler = samplers.BalancedIdentitySampler(
66+
train_set.img_items, mini_batch_size, num_instance
67+
)
6468
elif sampler_name == "SetReWeightSampler":
6569
set_weight = cfg.DATALOADER.SET_WEIGHT
66-
sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight)
70+
sampler = samplers.SetReWeightSampler(
71+
train_set.img_items, mini_batch_size, num_instance, set_weight
72+
)
6773
elif sampler_name == "ImbalancedDatasetSampler":
6874
sampler = samplers.ImbalancedDatasetSampler(train_set.img_items)
6975
else:
@@ -79,7 +85,11 @@ def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=N
7985

8086
@configurable(from_config=_train_loader_from_config)
8187
def build_reid_train_loader(
82-
train_set, *, sampler=None, total_batch_size, num_workers=0,
88+
train_set,
89+
*,
90+
sampler=None,
91+
total_batch_size,
92+
num_workers=0,
8393
):
8494
"""
8595
Build a dataloader for object re-identification with some default features.
@@ -91,7 +101,9 @@ def build_reid_train_loader(
91101

92102
mini_batch_size = total_batch_size // comm.get_world_size()
93103

94-
batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, mini_batch_size, True)
104+
batch_sampler = torch.utils.data.sampler.BatchSampler(
105+
sampler, mini_batch_size, True
106+
)
95107

96108
train_loader = DataLoaderX(
97109
comm.get_local_rank(),
@@ -105,12 +117,16 @@ def build_reid_train_loader(
105117
return train_loader
106118

107119

108-
def _test_loader_from_config(cfg, *, dataset_name=None, test_set=None, num_query=0, transforms=None, **kwargs):
120+
def _test_loader_from_config(
121+
cfg, *, dataset_name=None, test_set=None, num_query=0, transforms=None, **kwargs
122+
):
109123
if transforms is None:
110124
transforms = build_transforms(cfg, is_train=False)
111125

112126
if test_set is None:
113-
assert dataset_name is not None, "dataset_name must be explicitly passed in when test_set is not provided"
127+
assert (
128+
dataset_name is not None
129+
), "dataset_name must be explicitly passed in when test_set is not provided"
114130
data = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
115131
if comm.is_main_process():
116132
data.show_test()
@@ -184,7 +200,9 @@ def fast_batch_collator(batched_inputs):
184200
return out
185201

186202
elif isinstance(elem, Mapping):
187-
return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}
203+
return {
204+
key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem
205+
}
188206

189207
elif isinstance(elem, float):
190208
return torch.tensor(batched_inputs, dtype=torch.float64)

fastreid/evaluation/testing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import logging
33
import pprint
44
import sys
5-
from collections import Mapping, OrderedDict
5+
from collections import OrderedDict
6+
from collections.abc import Mapping
67

78
import numpy as np
89
from tabulate import tabulate
@@ -20,7 +21,7 @@ def print_csv_format(results):
2021
assert isinstance(results, OrderedDict) or not len(results), results
2122
logger = logging.getLogger(__name__)
2223

23-
dataset_name = results.pop('dataset')
24+
dataset_name = results.pop("dataset")
2425
metrics = ["Dataset"] + [k for k in results]
2526
csv_results = [(dataset_name, *list(results.values()))]
2627

0 commit comments

Comments
 (0)