11import time
22import os
33import logging
4+ from typing import Union
45
56import torch
67import torch .distributed as dist
78from aperturedb import Images
89from aperturedb import PyTorchDataset
10+ from torch .utils .data .dataloader import DataLoader
11+ from torch .utils .data .dataset import Dataset
912
1013logger = logging .getLogger (__name__ )
1114
1215
1316class TestTorchDatasets ():
14- def validate_dataset (self , dataset ):
17+ def validate_dataset (self , dataset : Union [ DataLoader , Dataset ], expected_length ):
1518 start = time .time ()
1619
20+ count = 0
1721 # Iterate over dataset.
1822 for img in dataset :
1923 if len (img [0 ]) < 0 :
2024 logger .error ("Empty image?" )
2125 assert True == False
26+ count += len (img [1 ]) if isinstance (dataset , DataLoader ) else 1
27+ assert count == expected_length
2228
23- logger . info ( " \n " )
24- logger . info ( "Throughput (imgs/s):" ,
25- len (dataset ) / ( time . time () - start ) )
29+ time_taken = time . time () - start
30+ if time_taken != 0 :
31+ logger . info ( f"Throughput (imgs/s): { len (dataset ) / time_taken } " )
2632
2733 def test_omConstraints (self , db , utils , images ):
2834 assert len (images ) > 0
@@ -31,8 +37,7 @@ def test_omConstraints(self, db, utils, images):
3137 dataset = PyTorchDataset .ApertureDBDatasetConstraints (
3238 db , constraints = const )
3339
34- assert len (dataset ) == utils .count_images ()
35- self .validate_dataset (dataset )
40+ self .validate_dataset (dataset , utils .count_images ())
3641
3742 def test_nativeContraints (self , db , utils , images ):
3843 assert len (images ) > 0
@@ -57,10 +62,10 @@ def test_nativeContraints(self, db, utils, images):
5762 dataset = PyTorchDataset .ApertureDBDataset (
5863 db , query , label_prop = "license" )
5964
60- assert len (dataset ) == utils .count_images ()
61- self .validate_dataset (dataset )
65+ self .validate_dataset (dataset , utils .count_images ())
6266
6367 def test_datasetWithMultiprocessing (self , db , utils ):
68+ len_limit = utils .count_images ()
6469 query = [{
6570 "FindImage" : {
6671 "constraints" : {
@@ -74,16 +79,16 @@ def test_datasetWithMultiprocessing(self, db, utils):
7479 }
7580 ],
7681 "results" : {
77- "list" : ["license" ]
82+ "list" : ["license" ],
83+ "limit" : len_limit
7884 }
7985 }
8086 }]
8187
8288 dataset = PyTorchDataset .ApertureDBDataset (
8389 db , query , label_prop = "license" )
8490
85- assert len (dataset ) == utils .count_images ()
86- self .validate_dataset (dataset )
91+ self .validate_dataset (dataset , len_limit )
8792 # Distributed Data Loader Setup
8893
8994 # Needed for init_process_group
@@ -93,30 +98,30 @@ def test_datasetWithMultiprocessing(self, db, utils):
9398 dist .init_process_group ("gloo" , rank = 0 , world_size = 1 )
9499
95100 # === Distributed Data Loader Sequential
96-
97- data_loader = torch . utils . data . DataLoader (
101+ batch_size = 10
102+ data_loader = DataLoader (
98103 dataset ,
99- batch_size = 10 , # pick random values here to test
104+ batch_size = batch_size , # pick random values here to test
100105 num_workers = 4 , # num_workers > 1 to test multiprocessing works
101106 pin_memory = True ,
102107 drop_last = True ,
103108 )
104109
105- self .validate_dataset (data_loader )
110+ self .validate_dataset (data_loader , len_limit )
106111 # === Distributed Data Loader Shuffler
107112
108113 # This will generate a random sampler, which will make the use
109114 # of batching wasteful
110115 sampler = torch .utils .data .DistributedSampler (
111116 dataset , shuffle = True )
112117
113- data_loader = torch . utils . data . DataLoader (
118+ data_loader = DataLoader (
114119 dataset ,
115120 sampler = sampler ,
116- batch_size = 10 , # pick random values here to test
121+ batch_size = batch_size , # pick random values here to test
117122 num_workers = 4 , # num_workers > 1 to test multiprocessing works
118123 pin_memory = True ,
119124 drop_last = True ,
120125 )
121126
122- self .validate_dataset (data_loader )
127+ self .validate_dataset (data_loader , len_limit )