
Commit a2a83a8

Flatten dataset on the fly in save_to_disk (#5588)
1 parent c4f14de commit a2a83a8
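
In short: `save_to_disk` no longer calls `flatten_indices()` upfront when the dataset carries an indices mapping (e.g. after `select()` or `shuffle()`); each shard now resolves its indices on the fly while it is written. A minimal sketch of the scenario this affects (the toy data and output path below are illustrative, not part of the commit):

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(1000))})
subset = ds.select(range(0, 1000, 2))  # keeps an indices mapping instead of copying the data

# Before this commit, save_to_disk() first materialized the selection via flatten_indices();
# with this commit, the indices are resolved shard by shard during the write itself.
subset.save_to_disk("even_rows_dataset")
```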

File tree

1 file changed: +6 −14 lines changed


src/datasets/arrow_dataset.py

Lines changed: 6 additions & 14 deletions
@@ -1345,8 +1345,6 @@ def save_to_disk(
         if self.list_indexes():
             raise ValueError("please remove all the indexes using `dataset.drop_index` before saving a dataset")
 
-        dataset = self.flatten_indices(num_proc=num_proc) if self._indices is not None else self
-
         if is_local:
             Path(dataset_path).resolve().mkdir(parents=True, exist_ok=True)
             parent_cache_files_paths = {
@@ -1360,7 +1358,7 @@ def save_to_disk(
 
         # Get json serializable state
         state = {
-            key: dataset.__dict__[key]
+            key: self.__dict__[key]
             for key in [
                 "_fingerprint",
                 "_format_columns",
@@ -1369,7 +1367,7 @@ def save_to_disk(
                 "_output_all_columns",
             ]
         }
-        state["_split"] = str(dataset.split) if dataset.split is not None else dataset.split
+        state["_split"] = str(self.split) if self.split is not None else self.split
         state["_data_files"] = [
             {"filename": f"data-{shard_idx:05d}-of-{num_shards:05d}.arrow"} for shard_idx in range(num_shards)
         ]
@@ -1381,20 +1379,20 @@ def save_to_disk(
                     str(e) + f"\nThe format kwargs must be JSON serializable, but key '{k}' isn't."
                 ) from None
         # Get json serializable dataset info
-        dataset_info = asdict(dataset._info)
+        dataset_info = asdict(self._info)
 
         shards_done = 0
         pbar = logging.tqdm(
             disable=not logging.is_progress_bar_enabled(),
             unit=" examples",
-            total=len(dataset),
+            total=len(self),
             leave=False,
             desc=f"Saving the dataset ({shards_done}/{num_shards} shards)",
         )
         kwargs_per_job = (
             {
                 "job_id": shard_idx,
-                "shard": dataset.shard(num_shards=num_shards, index=shard_idx, contiguous=True),
+                "shard": self.shard(num_shards=num_shards, index=shard_idx, contiguous=True),
                 "fpath": path_join(dataset_path, f"data-{shard_idx:05d}-of-{num_shards:05d}.arrow"),
                 "storage_options": storage_options,
             }
@@ -1439,12 +1437,6 @@ def save_to_disk(
     def _save_to_disk_single(job_id: int, shard: "Dataset", fpath: str, storage_options: Optional[dict]):
         batch_size = config.DEFAULT_MAX_BATCH_SIZE
 
-        if shard._indices is not None:
-            raise ValueError(
-                "`_save_to_disk_single` only support shards with flattened indices. "
-                "Please call ds.flatten_indices() before saving to disk."
-            )
-
         num_examples_progress_update = 0
         writer = ArrowWriter(
             features=shard.features,
@@ -1454,7 +1446,7 @@ def _save_to_disk_single(job_id: int, shard: "Dataset", fpath: str, storage_opti
         )
         try:
             _time = time.time()
-            for pa_table in table_iter(shard.data, batch_size=batch_size):
+            for pa_table in shard.with_format("arrow").iter(batch_size):
                 writer.write_table(pa_table)
                 num_examples_progress_update += len(pa_table)
                 if time.time() > _time + config.PBAR_REFRESH_TIME_INTERVAL:
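
For context: the removed `table_iter(shard.data, ...)` call iterated the raw Arrow table and therefore only worked on shards whose indices were already flattened (hence the removed guard above), while `shard.with_format("arrow").iter(batch_size)` goes through the dataset's formatting layer and applies any indices mapping as it batches. A minimal sketch of that iteration pattern (the toy dataset is illustrative only):

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(10))}).select([9, 3, 5])  # non-trivial indices mapping

for pa_table in ds.with_format("arrow").iter(batch_size=2):
    # Each batch arrives as a pyarrow.Table with the indices mapping already applied.
    print(pa_table.column("x").to_pylist())
```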

0 commit comments
