diff --git a/src/datasets/load.py b/src/datasets/load.py index 458b917c4f5..cf5fd33c545 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -2088,6 +2088,28 @@ def load_dataset( **config_kwargs, ) + # Get available splits + available_splits = builder_instance.info.splits.keys() + + # If there are subdatasets or splits, print them and ask for user input + if available_splits: + print("Available splits or subdatasets:") + for i, split_name in enumerate(available_splits): + print(f"{i + 1}. {split_name}") + + # Ask for user input to choose the split/subdataset + while True: + try: + user_choice = int(input("Select the split to load (enter the number): ")) - 1 + if user_choice < 0 or user_choice >= len(available_splits): + raise ValueError("Invalid choice. Please try again.") + break + except ValueError as e: + print(e) + + # Load the dataset based on user's choice + split = list(available_splits)[user_choice] + # Return iterable dataset in case of streaming if streaming: return builder_instance.as_streaming_dataset(split=split) @@ -2101,11 +2123,12 @@ def load_dataset( storage_options=storage_options, ) - # Build dataset for splits + # Build dataset for the chosen split keep_in_memory = ( keep_in_memory if keep_in_memory is not None else is_small_dataset(builder_instance.info.dataset_size) ) ds = builder_instance.as_dataset(split=split, verification_mode=verification_mode, in_memory=keep_in_memory) + if save_infos: builder_instance._save_infos()