Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2088,6 +2088,28 @@ def load_dataset(
**config_kwargs,
)

# Get available splits
available_splits = builder_instance.info.splits.keys()

# If there are subdatasets or splits, print them and ask for user input
if available_splits:
print("Available splits or subdatasets:")
for i, split_name in enumerate(available_splits):
print(f"{i + 1}. {split_name}")

# Ask for user input to choose the split/subdataset
while True:
try:
user_choice = int(input("Select the split to load (enter the number): ")) - 1
if user_choice < 0 or user_choice >= len(available_splits):
raise ValueError("Invalid choice. Please try again.")
break
except ValueError as e:
print(e)

# Load the dataset based on user's choice
split = list(available_splits)[user_choice]

# Return iterable dataset in case of streaming
if streaming:
return builder_instance.as_streaming_dataset(split=split)
Expand All @@ -2101,11 +2123,12 @@ def load_dataset(
storage_options=storage_options,
)

# Build dataset for splits
# Build dataset for the chosen split
keep_in_memory = (
keep_in_memory if keep_in_memory is not None else is_small_dataset(builder_instance.info.dataset_size)
)
ds = builder_instance.as_dataset(split=split, verification_mode=verification_mode, in_memory=keep_in_memory)

if save_infos:
builder_instance._save_infos()

Expand Down