@@ -254,52 +254,68 @@ def __str__(self) -> str:
254254
255255 """
256256
257- def dataset_info (loader : DataLoader ) -> tuple [str , str ]:
257+ class dataset_info :
258+ def __init__ (self , available : str , length : str ) -> None :
259+ self .available = available
260+ self .length = length
261+
262+ def retrieve_dataset_info (loader : DataLoader ) -> dataset_info :
258263 """Helper function to compute dataset information."""
259264 dataset = loader .dataset
260265 size : str = str (len (dataset )) if isinstance (dataset , Sized ) else "unknown"
266+ output = dataset_info ("yes" , size )
267+ return output
261268
262- return "yes" , size
263-
264- def loader_info ( loader_instance : Union [DataLoader , Iterable [DataLoader ]]) -> str :
269+ def loader_info (
270+ loader_instance : Union [ DataLoader , Iterable [ DataLoader ]],
271+ ) -> Union [dataset_info , Iterable [dataset_info ]] :
265272 """Helper function to compute dataset information."""
266- result = apply_to_collection (loader_instance , DataLoader , dataset_info )
273+ result = apply_to_collection (loader_instance , DataLoader , retrieve_dataset_info )
267274
268275 return result
269276
277+ def extract_loader_info (methods : list [tuple [str , str ]]) -> dict :
278+ """Helper function to extract information for each dataloader method."""
279+ info : dict [str , Union [dataset_info , Iterable [dataset_info ]]] = {}
280+ for method_str , function_name in methods :
281+ loader_method = getattr (self , function_name , None )
282+
283+ try :
284+ loader_instance = loader_method ()
285+ info [method_str ] = loader_info (loader_instance )
286+ except Exception :
287+ info [method_str ] = dataset_info ("no" , "unknown" )
288+
289+ return info
290+
291+ def format_loader_info (info : dict [str , Union [dataset_info , Iterable [dataset_info ]]]) -> str :
292+ """Helper function to format loader information."""
293+ lines = []
294+ for method_str , method_info in info .items ():
295+ # Single dataset
296+ if isinstance (method_info , dataset_info ):
297+ data_info = f"{{{ method_str } : available={ method_info .available } , size={ method_info .length } }}"
298+ lines .append (data_info )
299+ # Iterable of datasets
300+ else :
301+ itr_data_info = " ; " .join (
302+ f"{ i } . available={ dataset .available } , size={ dataset .length } "
303+ for i , dataset in enumerate (method_info , start = 1 )
304+ )
305+ lines .append (f"{{{ method_str } : { itr_data_info } }}" )
306+
307+ return os .linesep .join (lines )
308+
309+ # Available dataloader methods
270310 dataloader_methods : list [tuple [str , str ]] = [
271311 ("Train dataset" , "train_dataloader" ),
272312 ("Validation dataset" , "val_dataloader" ),
273313 ("Test dataset" , "test_dataloader" ),
274314 ("Prediction dataset" , "predict_dataloader" ),
275315 ]
276- dataloader_info : dict [str , Union [tuple [str , str ], Iterable [tuple [str , str ]]]] = {}
277316
278317 # Retrieve information for each dataloader method
279- for method_pair in dataloader_methods :
280- method_str , method_name = method_pair
281- loader_method = getattr (self , method_name , None )
282-
283- try :
284- loader_instance = loader_method ()
285- dataloader_info [method_str ] = loader_info (loader_instance )
286- except Exception :
287- dataloader_info [method_str ] = ("no" , "unknown" )
288-
318+ dataloader_info = extract_loader_info (dataloader_methods )
289319 # Format the information
290- dataloader_str : str = ""
291- for method_str , method_info in dataloader_info .items ():
292- # Single data set
293- if isinstance (method_info , tuple ):
294- dataloader_str += f"{{{ method_str } : "
295- dataloader_str += f"available={ method_info [0 ]} , size={ method_info [1 ]} "
296- dataloader_str += f"}}{ os .linesep } "
297- else :
298- # Iterable of datasets
299- dataloader_str += f"{{{ method_str } : "
300- for i , info in enumerate (method_info , start = 1 ):
301- dataloader_str += f"{ i } . available={ info [0 ]} , size={ info [1 ]} ; "
302- dataloader_str = dataloader_str [:- 3 ]
303- dataloader_str += f"}}{ os .linesep } "
304-
320+ dataloader_str = format_loader_info (dataloader_info )
305321 return dataloader_str
0 commit comments