@@ -1345,8 +1345,6 @@ def save_to_disk(
         if self.list_indexes():
             raise ValueError("please remove all the indexes using `dataset.drop_index` before saving a dataset")
 
-        dataset = self.flatten_indices(num_proc=num_proc) if self._indices is not None else self
-
         if is_local:
             Path(dataset_path).resolve().mkdir(parents=True, exist_ok=True)
             parent_cache_files_paths = {
@@ -1360,7 +1358,7 @@ def save_to_disk(
 
         # Get json serializable state
         state = {
-            key: dataset.__dict__[key]
+            key: self.__dict__[key]
             for key in [
                 "_fingerprint",
                 "_format_columns",
@@ -1369,7 +1367,7 @@ def save_to_disk(
13691367 "_output_all_columns" ,
13701368 ]
13711369 }
1372- state ["_split" ] = str (dataset .split ) if dataset .split is not None else dataset .split
1370+ state ["_split" ] = str (self .split ) if self .split is not None else self .split
13731371 state ["_data_files" ] = [
13741372 {"filename" : f"data-{ shard_idx :05d} -of-{ num_shards :05d} .arrow" } for shard_idx in range (num_shards )
13751373 ]
@@ -1381,20 +1379,20 @@ def save_to_disk(
                     str(e) + f"\nThe format kwargs must be JSON serializable, but key '{k}' isn't."
                 ) from None
         # Get json serializable dataset info
-        dataset_info = asdict(dataset._info)
+        dataset_info = asdict(self._info)
 
         shards_done = 0
         pbar = logging.tqdm(
             disable=not logging.is_progress_bar_enabled(),
             unit=" examples",
-            total=len(dataset),
+            total=len(self),
             leave=False,
             desc=f"Saving the dataset ({shards_done}/{num_shards} shards)",
         )
         kwargs_per_job = (
             {
                 "job_id": shard_idx,
-                "shard": dataset.shard(num_shards=num_shards, index=shard_idx, contiguous=True),
+                "shard": self.shard(num_shards=num_shards, index=shard_idx, contiguous=True),
                 "fpath": path_join(dataset_path, f"data-{shard_idx:05d}-of-{num_shards:05d}.arrow"),
                 "storage_options": storage_options,
             }
@@ -1439,12 +1437,6 @@ def save_to_disk(
     def _save_to_disk_single(job_id: int, shard: "Dataset", fpath: str, storage_options: Optional[dict]):
         batch_size = config.DEFAULT_MAX_BATCH_SIZE
 
-        if shard._indices is not None:
-            raise ValueError(
-                "`_save_to_disk_single` only support shards with flattened indices. "
-                "Please call ds.flatten_indices() before saving to disk."
-            )
-
         num_examples_progress_update = 0
         writer = ArrowWriter(
             features=shard.features,
@@ -1454,7 +1446,7 @@ def _save_to_disk_single(job_id: int, shard: "Dataset", fpath: str, storage_options: Optional[dict]):
         )
         try:
             _time = time.time()
-            for pa_table in table_iter(shard.data, batch_size=batch_size):
+            for pa_table in shard.with_format("arrow").iter(batch_size):
                 writer.write_table(pa_table)
                 num_examples_progress_update += len(pa_table)
                 if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL:
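The eager `flatten_indices()` pass is dropped because the per-shard writer now iterates each shard via `with_format("arrow").iter()`, which is expected to resolve any indices mapping (e.g. from `select()` or `shuffle()`) as it yields batches. A minimal sketch of that behavior on a throwaway dataset (not part of this patch):

```python
# Minimal sketch (not from this patch): iterating in Arrow format yields
# pyarrow tables that already follow the dataset's indices mapping, so no
# explicit flatten_indices() call is needed before writing shards to disk.
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(10))}).select([3, 1, 4])  # select() creates an indices mapping

for pa_table in ds.with_format("arrow").iter(batch_size=2):
    print(pa_table.column("x").to_pylist())  # prints [3, 1] then [4]: the selected rows, in order
```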