Skip to content

Commit 6ad365c

Browse files
committed
Add usps dataloader as alternative
1 parent 6dfd94d commit 6ad365c

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

utils/load_data.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from dataloaders import USPS_0_6
12
from torch.utils.data import Dataset
23

34

45
def load_data(dataset: str) -> Dataset:
5-
raise ValueError(
6-
f"Dataset: {dataset} not implemented. \nCheck the documentation for implemented metrics, or check your spelling"
7-
)
6+
match dataset.lower():
7+
case "usps_0-6":
8+
return USPS_0_6
9+
case _:
10+
raise ValueError(f"Dataset: {dataset} not implemented.")

0 commit comments

Comments
 (0)