Skip to content

Commit 330a88c

Browse files
Finilized required adjustments for dataloader string proposal method
Renamed varaibles to more sensible names to increase readability
1 parent 21029d2 commit 330a88c

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

src/lightning/pytorch/core/datamodule.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -263,59 +263,57 @@ def retrieve_dataset_info(loader: DataLoader) -> dataset_info:
263263
"""Helper function to compute dataset information."""
264264
dataset = loader.dataset
265265
size: str = str(len(dataset)) if isinstance(dataset, Sized) else "unknown"
266-
output = dataset_info("yes", size)
267-
return output
266+
267+
return dataset_info("yes", size)
268268

269269
def loader_info(
270-
loader_instance: Union[DataLoader, Iterable[DataLoader]],
270+
loader: Union[DataLoader, Iterable[DataLoader]],
271271
) -> Union[dataset_info, Iterable[dataset_info]]:
272272
"""Helper function to compute dataset information."""
273-
result = apply_to_collection(loader_instance, DataLoader, retrieve_dataset_info)
274-
275-
return result
273+
return apply_to_collection(loader, DataLoader, retrieve_dataset_info)
276274

277275
def extract_loader_info(methods: list[tuple[str, str]]) -> dict:
278276
"""Helper function to extract information for each dataloader method."""
279277
info: dict[str, Union[dataset_info, Iterable[dataset_info]]] = {}
280-
for method_str, function_name in methods:
281-
loader_method = getattr(self, function_name, None)
278+
for loader_name, func_name in methods:
279+
loader_callback = getattr(self, func_name, None)
282280

283281
try:
284-
loader_instance = loader_method()
285-
info[method_str] = loader_info(loader_instance)
282+
loader = loader_callback()
283+
info[loader_name] = loader_info(loader)
286284
except Exception:
287-
info[method_str] = dataset_info("no", "unknown")
285+
info[loader_name] = dataset_info("no", "unknown")
288286

289287
return info
290288

291289
def format_loader_info(info: dict[str, Union[dataset_info, Iterable[dataset_info]]]) -> str:
292290
"""Helper function to format loader information."""
293-
lines = []
294-
for method_str, method_info in info.items():
291+
output = []
292+
for loader_name, loader_info in info.items():
295293
# Single dataset
296-
if isinstance(method_info, dataset_info):
297-
data_info = f"{{{method_str}: available={method_info.available}, size={method_info.length}}}"
298-
lines.append(data_info)
294+
if isinstance(loader_info, dataset_info):
295+
loader_info_formatted = f"available={loader_info.available}, size={loader_info.length}"
299296
# Iterable of datasets
300297
else:
301-
itr_data_info = " ; ".join(
302-
f"{i}. available={dataset.available}, size={dataset.length}"
303-
for i, dataset in enumerate(method_info, start=1)
298+
loader_info_formatted = " ; ".join(
299+
f"{i}. available={loader_info_i.available}, size={loader_info_i.length}"
300+
for i, loader_info_i in enumerate(loader_info, start=1)
304301
)
305-
lines.append(f"{{{method_str}: {itr_data_info}}}")
306302

307-
return os.linesep.join(lines)
303+
output.append(f"{{{loader_name}: {loader_info_formatted}}}")
304+
305+
return os.linesep.join(output)
308306

309307
# Available dataloader methods
310-
dataloader_methods: list[tuple[str, str]] = [
308+
datamodule_loader_methods: list[tuple[str, str]] = [
311309
("Train dataset", "train_dataloader"),
312310
("Validation dataset", "val_dataloader"),
313311
("Test dataset", "test_dataloader"),
314312
("Prediction dataset", "predict_dataloader"),
315313
]
316314

317315
# Retrieve information for each dataloader method
318-
dataloader_info = extract_loader_info(dataloader_methods)
316+
dataloader_info = extract_loader_info(datamodule_loader_methods)
319317
# Format the information
320318
dataloader_str = format_loader_info(dataloader_info)
321319
return dataloader_str

0 commit comments

Comments
 (0)