88import  os 
99
1010import  torch 
11- TORCH_MAJOR  =  int (torch .__version__ .split ('.' )[0 ])
12- TORCH_MINOR  =  int (torch .__version__ .split ('.' )[1 ])
11+ 
12+ TORCH_MAJOR  =  int (torch .__version__ .split ("." )[0 ])
13+ TORCH_MINOR  =  int (torch .__version__ .split ("." )[1 ])
1314
1415if  TORCH_MAJOR  ==  1  and  TORCH_MINOR  <  8 :
1516    from  torch ._six  import  string_classes 
1617else :
1718    string_classes  =  str 
1819
19- from  collections  import  Mapping 
20+ from  collections . abc  import  Mapping 
2021
2122from  fastreid .config  import  configurable 
2223from  fastreid .utils  import  comm 
2627from  .datasets  import  DATASET_REGISTRY 
2728from  .transforms  import  build_transforms 
2829
29- __all__  =  [
30-     "build_reid_train_loader" ,
31-     "build_reid_test_loader" 
32- ]
30+ __all__  =  ["build_reid_train_loader" , "build_reid_test_loader" ]
3331
3432_root  =  os .getenv ("FASTREID_DATASETS" , "datasets" )
3533
3634
37- def  _train_loader_from_config (cfg , * , train_set = None , transforms = None , sampler = None , ** kwargs ):
35+ def  _train_loader_from_config (
36+     cfg , * , train_set = None , transforms = None , sampler = None , ** kwargs 
37+ ):
3838    if  transforms  is  None :
3939        transforms  =  build_transforms (cfg , is_train = True )
4040
@@ -58,12 +58,18 @@ def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=N
5858        if  sampler_name  ==  "TrainingSampler" :
5959            sampler  =  samplers .TrainingSampler (len (train_set ))
6060        elif  sampler_name  ==  "NaiveIdentitySampler" :
61-             sampler  =  samplers .NaiveIdentitySampler (train_set .img_items , mini_batch_size , num_instance )
61+             sampler  =  samplers .NaiveIdentitySampler (
62+                 train_set .img_items , mini_batch_size , num_instance 
63+             )
6264        elif  sampler_name  ==  "BalancedIdentitySampler" :
63-             sampler  =  samplers .BalancedIdentitySampler (train_set .img_items , mini_batch_size , num_instance )
65+             sampler  =  samplers .BalancedIdentitySampler (
66+                 train_set .img_items , mini_batch_size , num_instance 
67+             )
6468        elif  sampler_name  ==  "SetReWeightSampler" :
6569            set_weight  =  cfg .DATALOADER .SET_WEIGHT 
66-             sampler  =  samplers .SetReWeightSampler (train_set .img_items , mini_batch_size , num_instance , set_weight )
70+             sampler  =  samplers .SetReWeightSampler (
71+                 train_set .img_items , mini_batch_size , num_instance , set_weight 
72+             )
6773        elif  sampler_name  ==  "ImbalancedDatasetSampler" :
6874            sampler  =  samplers .ImbalancedDatasetSampler (train_set .img_items )
6975        else :
@@ -79,7 +85,11 @@ def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=N
7985
8086@configurable (from_config = _train_loader_from_config ) 
8187def  build_reid_train_loader (
82-         train_set , * , sampler = None , total_batch_size , num_workers = 0 ,
88+     train_set ,
89+     * ,
90+     sampler = None ,
91+     total_batch_size ,
92+     num_workers = 0 ,
8393):
8494    """ 
8595    Build a dataloader for object re-identification with some default features. 
@@ -91,7 +101,9 @@ def build_reid_train_loader(
91101
92102    mini_batch_size  =  total_batch_size  //  comm .get_world_size ()
93103
94-     batch_sampler  =  torch .utils .data .sampler .BatchSampler (sampler , mini_batch_size , True )
104+     batch_sampler  =  torch .utils .data .sampler .BatchSampler (
105+         sampler , mini_batch_size , True 
106+     )
95107
96108    train_loader  =  DataLoaderX (
97109        comm .get_local_rank (),
@@ -105,12 +117,16 @@ def build_reid_train_loader(
105117    return  train_loader 
106118
107119
108- def  _test_loader_from_config (cfg , * , dataset_name = None , test_set = None , num_query = 0 , transforms = None , ** kwargs ):
120+ def  _test_loader_from_config (
121+     cfg , * , dataset_name = None , test_set = None , num_query = 0 , transforms = None , ** kwargs 
122+ ):
109123    if  transforms  is  None :
110124        transforms  =  build_transforms (cfg , is_train = False )
111125
112126    if  test_set  is  None :
113-         assert  dataset_name  is  not   None , "dataset_name must be explicitly passed in when test_set is not provided" 
127+         assert  (
128+             dataset_name  is  not   None 
129+         ), "dataset_name must be explicitly passed in when test_set is not provided" 
114130        data  =  DATASET_REGISTRY .get (dataset_name )(root = _root , ** kwargs )
115131        if  comm .is_main_process ():
116132            data .show_test ()
@@ -184,7 +200,9 @@ def fast_batch_collator(batched_inputs):
184200        return  out 
185201
186202    elif  isinstance (elem , Mapping ):
187-         return  {key : fast_batch_collator ([d [key ] for  d  in  batched_inputs ]) for  key  in  elem }
203+         return  {
204+             key : fast_batch_collator ([d [key ] for  d  in  batched_inputs ]) for  key  in  elem 
205+         }
188206
189207    elif  isinstance (elem , float ):
190208        return  torch .tensor (batched_inputs , dtype = torch .float64 )
0 commit comments