 from .features import Features
 from .features.features import FeatureType
 from .info import DatasetInfo, DatasetInfosDict
+from .iterable_dataset import IterableDataset
 from .naming import _split_re
 from .splits import NamedSplit, Split, SplitDict, SplitInfo
 from .table import Table
@@ -49,7 +50,7 @@ def __call__(self, *fn_args, **fn_kwargs):
         return self.func(*fn_args, *self.args, **fn_kwargs)
 
 
-class DatasetDict(dict):
+class DatasetDict(dict[Union[str, NamedSplit], "Dataset"]):
     """A dictionary (dict of str: datasets.Dataset) with dataset transform methods (map, filter, etc.)"""
 
     def _check_values_type(self):
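Parametrizing the `dict` base class documents the key and value types for static checkers; a minimal sketch of the effect (not part of the diff):

```py
from datasets import DatasetDict, load_dataset

ds: DatasetDict = load_dataset("cornell-movie-review-data/rotten_tomatoes")
# Thanks to dict[Union[str, NamedSplit], "Dataset"], type checkers now
# know that indexing a DatasetDict yields a Dataset, so attributes like
# num_rows resolve without a cast.
print(ds["train"].num_rows)
```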
@@ -1616,6 +1617,7 @@ def push_to_hub(
         max_shard_size: Optional[Union[int, str]] = None,
         num_shards: Optional[dict[str, int]] = None,
         embed_external_files: bool = True,
+        num_proc: Optional[int] = None,
     ) -> CommitInfo:
         """Pushes the [`DatasetDict`] to the hub as a Parquet dataset.
         The [`DatasetDict`] is pushed using HTTP requests and does not require git or git-lfs to be installed.
@@ -1676,6 +1678,12 @@ def push_to_hub(
                 In particular, this will do the following before the push for the fields of type:
 
                 - [`Audio`] and [`Image`] removes local path information and embeds file content in the Parquet files.
+            num_proc (`int`, *optional*, defaults to `None`):
+                Number of processes to use when preparing and uploading the dataset.
+                This is helpful if the dataset is made of many samples or media files to embed.
+                Multiprocessing is disabled by default.
+
+                <Added version="4.0.0"/>
 
         Return:
             huggingface_hub.CommitInfo
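A usage sketch for the new argument (the repo id is hypothetical):

```py
from datasets import load_dataset

ds = load_dataset("cornell-movie-review-data/rotten_tomatoes")
# Prepare and upload the Parquet shards with 8 worker processes;
# with the default num_proc=None everything runs in the main process.
ds.push_to_hub("username/rotten-tomatoes-copy", num_proc=8)
```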
@@ -1756,6 +1764,7 @@ def push_to_hub(
                 max_shard_size=max_shard_size,
                 num_shards=num_shards.get(split),
                 embed_external_files=embed_external_files,
+                num_proc=num_proc,
             )
             additions += split_additions
             total_uploaded_size += uploaded_size
@@ -1910,12 +1919,61 @@ def push_to_hub(
         return commit_info
 
 
-class IterableDatasetDict(dict):
+class IterableDatasetDict(dict[Union[str, NamedSplit], IterableDataset]):
+    def _check_values_type(self):
+        for dataset in self.values():
+            if not isinstance(dataset, IterableDataset):
+                raise TypeError(f"Values in `IterableDatasetDict` should be of type `IterableDataset` but got type '{type(dataset)}'")
+
+    def _check_values_features(self):
+        items = [(key, dataset._resolve_features()) for key, dataset in self.items()]
+        for item_a, item_b in zip(items[:-1], items[1:]):
+            if item_a[1].features != item_b[1].features:
+                raise ValueError(
+                    f"All datasets in `IterableDatasetDict` should have the same features but features for '{item_a[0]}' and '{item_b[0]}' don't match: {item_a[1].features} != {item_b[1].features}"
+                )
+
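These helpers mirror their `DatasetDict` counterparts; a quick sketch of what the type check guards against, calling the private helper directly for illustration:

```py
from datasets import Dataset, IterableDatasetDict

# A map-style Dataset is not a valid value for an IterableDatasetDict,
# so the check raises a TypeError.
ddict = IterableDatasetDict({"train": Dataset.from_dict({"text": ["a", "b"]})})
ddict._check_values_type()  # TypeError
```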
     def __repr__(self):
         repr = "\n".join([f"{k}: {v}" for k, v in self.items()])
         repr = re.sub(r"^", " " * 4, repr, count=0, flags=re.M)
         return f"IterableDatasetDict({{\n{repr}\n}})"
 
+    @property
+    def num_columns(self) -> dict[str, Optional[int]]:
+        """Number of columns in each split of the dataset.
+        This can contain None values if some splits have unknown features (e.g. after a map() operation).
+
+        Example:
+
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", streaming=True)
+        >>> ds.num_columns
+        {'test': 2, 'train': 2, 'validation': 2}
+        ```
+        """
+        self._check_values_type()
+        return {k: dataset.num_columns for k, dataset in self.items()}
+
+    @property
+    def column_names(self) -> dict[str, Optional[list[str]]]:
+        """Names of the columns in each split of the dataset.
+        This can contain None values if some splits have unknown features (e.g. after a map() operation).
+
+        Example:
+
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", streaming=True)
+        >>> ds.column_names
+        {'test': ['text', 'label'],
+         'train': ['text', 'label'],
+         'validation': ['text', 'label']}
+        ```
+        """
+        self._check_values_type()
+        return {k: dataset.column_names for k, dataset in self.items()}
+
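A short sketch of the `None` case both docstrings mention: after a `map()` call the features can become unknown until they are resolved (output shown is indicative):

```py
from datasets import load_dataset

ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", streaming=True)
# map() with remove_columns drops the known features, so the column
# names of the transformed splits are no longer known upfront.
ds = ds.map(lambda x: {"n_chars": len(x["text"])}, remove_columns=["label"])
print(ds.column_names)  # {'train': None, 'validation': None, 'test': None}
```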
     def with_format(
         self,
         type: Optional[str] = None,
@@ -2385,6 +2443,7 @@ def push_to_hub(
         # max_shard_size: Optional[Union[int, str]] = None,  # TODO(QL): add arg
         num_shards: Optional[dict[str, int]] = None,
         embed_external_files: bool = True,
+        num_proc: Optional[int] = None,
     ) -> CommitInfo:
         """Pushes the [`IterableDatasetDict`] to the hub as a Parquet dataset.
         The [`IterableDatasetDict`] is pushed using HTTP requests and does not require git or git-lfs to be installed.
@@ -2436,6 +2495,12 @@ def push_to_hub(
                 In particular, this will do the following before the push for the fields of type:
 
                 - [`Audio`] and [`Image`] removes local path information and embeds file content in the Parquet files.
+            num_proc (`int`, *optional*, defaults to `None`):
+                Number of processes to use when preparing and uploading the dataset.
+                This is helpful if the dataset is made of many samples or media files to embed.
+                Multiprocessing is disabled by default.
+
+                <Added version="4.0.0"/>
 
         Return:
             huggingface_hub.CommitInfo
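The same argument on the streaming side; a hedged sketch (the repo id is hypothetical):

```py
from datasets import load_dataset

ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", streaming=True)
# Four worker processes stream, convert, and upload the Parquet shards.
ds.push_to_hub("username/rotten-tomatoes-streamed", num_proc=4)
```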
@@ -2505,7 +2570,7 @@ def push_to_hub(
         for split in self.keys():
             logger.info(f"Pushing split {split} to the Hub.")
             # The split=key needs to be removed before merging
-            split_additions, uploaded_size, dataset_nbytes = self[split]._push_parquet_shards_to_hub(
+            split_additions, uploaded_size, dataset_nbytes, num_examples = self[split]._push_parquet_shards_to_hub(
                 repo_id,
                 data_dir=data_dir,
                 split=split,
@@ -2515,11 +2580,12 @@ def push_to_hub(
                 # max_shard_size=max_shard_size,  # TODO(QL): add arg
                 num_shards=num_shards.get(split),
                 embed_external_files=embed_external_files,
+                num_proc=num_proc,
             )
             additions += split_additions
             total_uploaded_size += uploaded_size
             total_dataset_nbytes += dataset_nbytes
-            info_to_dump.splits[split] = SplitInfo(str(split), num_bytes=dataset_nbytes, num_examples=len(self[split]))
+            info_to_dump.splits[split] = SplitInfo(str(split), num_bytes=dataset_nbytes, num_examples=num_examples)
         info_to_dump.download_checksums = None
         info_to_dump.download_size = total_uploaded_size
         info_to_dump.dataset_size = total_dataset_nbytes