@@ -263,59 +263,57 @@ def retrieve_dataset_info(loader: DataLoader) -> dataset_info:
263263 """Helper function to compute dataset information."""
264264 dataset = loader .dataset
265265 size : str = str (len (dataset )) if isinstance (dataset , Sized ) else "unknown"
266- output = dataset_info ( "yes" , size )
267- return output
266+
267+ return dataset_info ( "yes" , size )
268268
269269 def loader_info (
270- loader_instance : Union [DataLoader , Iterable [DataLoader ]],
270+ loader : Union [DataLoader , Iterable [DataLoader ]],
271271 ) -> Union [dataset_info , Iterable [dataset_info ]]:
272272 """Helper function to compute dataset information."""
273- result = apply_to_collection (loader_instance , DataLoader , retrieve_dataset_info )
274-
275- return result
273+ return apply_to_collection (loader , DataLoader , retrieve_dataset_info )
276274
277275 def extract_loader_info (methods : list [tuple [str , str ]]) -> dict :
278276 """Helper function to extract information for each dataloader method."""
279277 info : dict [str , Union [dataset_info , Iterable [dataset_info ]]] = {}
280- for method_str , function_name in methods :
281- loader_method = getattr (self , function_name , None )
278+ for loader_name , func_name in methods :
279+ loader_callback = getattr (self , func_name , None )
282280
283281 try :
284- loader_instance = loader_method ()
285- info [method_str ] = loader_info (loader_instance )
282+ loader = loader_callback ()
283+ info [loader_name ] = loader_info (loader )
286284 except Exception :
287- info [method_str ] = dataset_info ("no" , "unknown" )
285+ info [loader_name ] = dataset_info ("no" , "unknown" )
288286
289287 return info
290288
291289 def format_loader_info (info : dict [str , Union [dataset_info , Iterable [dataset_info ]]]) -> str :
292290 """Helper function to format loader information."""
293- lines = []
294- for method_str , method_info in info .items ():
291+ output = []
292+ for loader_name , loader_info in info .items ():
295293 # Single dataset
296- if isinstance (method_info , dataset_info ):
297- data_info = f"{{{ method_str } : available={ method_info .available } , size={ method_info .length } }}"
298- lines .append (data_info )
294+ if isinstance (loader_info , dataset_info ):
295+ loader_info_formatted = f"available={ loader_info .available } , size={ loader_info .length } "
299296 # Iterable of datasets
300297 else :
301- itr_data_info = " ; " .join (
302- f"{ i } . available={ dataset .available } , size={ dataset .length } "
303- for i , dataset in enumerate (method_info , start = 1 )
298+ loader_info_formatted = " ; " .join (
299+ f"{ i } . available={ loader_info_i .available } , size={ loader_info_i .length } "
300+ for i , loader_info_i in enumerate (loader_info , start = 1 )
304301 )
305- lines .append (f"{{{ method_str } : { itr_data_info } }}" )
306302
307- return os .linesep .join (lines )
303+ output .append (f"{{{ loader_name } : { loader_info_formatted } }}" )
304+
305+ return os .linesep .join (output )
308306
309307 # Available dataloader methods
310- dataloader_methods : list [tuple [str , str ]] = [
308+ datamodule_loader_methods : list [tuple [str , str ]] = [
311309 ("Train dataset" , "train_dataloader" ),
312310 ("Validation dataset" , "val_dataloader" ),
313311 ("Test dataset" , "test_dataloader" ),
314312 ("Prediction dataset" , "predict_dataloader" ),
315313 ]
316314
317315 # Retrieve information for each dataloader method
318- dataloader_info = extract_loader_info (dataloader_methods )
316+ dataloader_info = extract_loader_info (datamodule_loader_methods )
319317 # Format the information
320318 dataloader_str = format_loader_info (dataloader_info )
321319 return dataloader_str
0 commit comments