
Commit c4f14de

Do not write index by default when exporting a dataset (#5583)
1 parent 939b233 commit c4f14de

File tree

8 files changed: +62 −55 lines changed


setup.py

Lines changed: 1 addition & 1 deletion
@@ -170,6 +170,7 @@
     "lz4",
     "py7zr",
     "rarfile>=4.0",
+    "sqlalchemy<2.0.0",
     "s3fs>=2021.11.1;python_version<'3.8'", # aligned with fsspec[http]>=2021.11.1; test only on python 3.7 for now
     "tensorflow>=2.3,!=2.6.0,!=2.6.1; sys_platform != 'darwin' or platform_machine != 'arm64'",
     "tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'",
@@ -196,7 +197,6 @@
     "scipy",
     "sentencepiece", # for bleurt
     "seqeval",
-    "sqlalchemy<2.0.0",
     "spacy>=3.0.0",
     "tldextract",
     # to speed up pip backtracking

src/datasets/arrow_dataset.py

Lines changed: 23 additions & 24 deletions
@@ -4597,7 +4597,6 @@ def to_csv(
         path_or_buf: Union[PathLike, BinaryIO],
         batch_size: Optional[int] = None,
         num_proc: Optional[int] = None,
-        index: bool = False,
         **to_csv_kwargs,
     ) -> int:
         """Exports the dataset to csv
@@ -4613,20 +4612,18 @@ def to_csv(
                 use multiprocessing. `batch_size` in this case defaults to
                 `datasets.config.DEFAULT_MAX_BATCH_SIZE` but feel free to make it 5x or 10x of the default
                 value if you have sufficient compute power.
-            index (`bool`, default `False`): Write row names (index).
+            **to_csv_kwargs (additional keyword arguments):
+                Parameters to pass to pandas's [`pandas.DataFrame.to_csv`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_csv.html).

                 <Changed version="2.10.0">

-                Now, `index` defaults to `False`.
+                Now, `index` defaults to `False` if not specified.

-                If you would like to write the index, set it to `True` and also set a name for the index column by
+                If you would like to write the index, pass `index=True` and also set a name for the index column by
                 passing `index_label`.

                 </Changed>

-            **to_csv_kwargs (additional keyword arguments):
-                Parameters to pass to pandas's `pandas.DataFrame.to_csv`.
-
         Returns:
             `int`: The number of characters or bytes written.
@@ -4639,9 +4636,7 @@ def to_csv(
         # Dynamic import to avoid circular dependency
         from .io.csv import CsvDatasetWriter

-        return CsvDatasetWriter(
-            self, path_or_buf, batch_size=batch_size, num_proc=num_proc, index=index, **to_csv_kwargs
-        ).write()
+        return CsvDatasetWriter(self, path_or_buf, batch_size=batch_size, num_proc=num_proc, **to_csv_kwargs).write()

     def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Union[dict, Iterator[dict]]:
         """Returns the dataset as a Python dict. Can also return a generator for large datasets.
@@ -4699,22 +4694,17 @@ def to_json(
                 use multiprocessing. `batch_size` in this case defaults to
                 `datasets.config.DEFAULT_MAX_BATCH_SIZE` but feel free to make it 5x or 10x of the default
                 value if you have sufficient compute power.
-            lines (`bool`, defaults to `True`):
-                Whether output JSON lines format.
-                Only possible if `orient="records"`. It will throw ValueError with `orient` different from
-                `"records"`, since the others are not list-like.
-            orient (`str`, defaults to `"records"`):
-                Format of the JSON:
-
-                - `"records"`: list like `[{column -> value}, … , {column -> value}]`
-                - `"split"`: dict like `{"index" -> [index], "columns" -> [columns], "data" -> [values]}`
-                - `"index"`: dict like `{index -> {column -> value}}`
-                - `"columns"`: dict like `{column -> {index -> value}}`
-                - `"values"`: just the values array
-                - `"table"`: dict like `{"schema": {schema}, "data": {data}}`
             **to_json_kwargs (additional keyword arguments):
                 Parameters to pass to pandas's [`pandas.DataFrame.to_json`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_json.html).

+                <Changed version="2.11.0">
+
+                Now, `index` defaults to `False` if `orient` is `"split"` or `"table"`.
+
+                If you would like to write the index, pass `index=True`.
+
+                </Changed>
+
         Returns:
             `int`: The number of characters or bytes written.
@@ -4817,7 +4807,16 @@ def to_sql(
                 Size of the batch to load in memory and write at once.
                 Defaults to `datasets.config.DEFAULT_MAX_BATCH_SIZE`.
             **sql_writer_kwargs (additional keyword arguments):
-                Parameters to pass to pandas's [`Dataframe.to_sql`].
+                Parameters to pass to pandas's [`pandas.DataFrame.to_sql`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_sql.html).
+
+                <Changed version="2.11.0">
+
+                Now, `index` defaults to `False` if not specified.
+
+                If you would like to write the index, pass `index=True` and also set a name for the index column by
+                passing `index_label`.
+
+                </Changed>

         Returns:
             `int`: The number of records written.
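A minimal sketch of the export behavior these updated docstrings describe; the example dataset, column names, and output paths below are placeholders, not part of the commit:

    from datasets import Dataset

    ds = Dataset.from_dict({"id": [0, 1], "text": ["foo", "bar"]})

    ds.to_csv("out.csv")  # no index column is written (new default)
    ds.to_csv("out.csv", index=True, index_label="idx")  # opt back in and name the index column
    ds.to_json("out.jsonl")  # records/lines output; the 2.11.0 index default applies to orient="split"/"table"
    ds.to_sql("data", "sqlite:///out.sqlite")  # index=False by default; needs sqlalchemy installed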

src/datasets/io/csv.py

Lines changed: 8 additions & 7 deletions
@@ -86,29 +86,30 @@ def __init__(

     def write(self) -> int:
         _ = self.to_csv_kwargs.pop("path_or_buf", None)
+        header = self.to_csv_kwargs.pop("header", True)
         index = self.to_csv_kwargs.pop("index", False)

         if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
             with open(self.path_or_buf, "wb+") as buffer:
-                written = self._write(file_obj=buffer, index=index, **self.to_csv_kwargs)
+                written = self._write(file_obj=buffer, header=header, index=index, **self.to_csv_kwargs)
         else:
-            written = self._write(file_obj=self.path_or_buf, index=index, **self.to_csv_kwargs)
+            written = self._write(file_obj=self.path_or_buf, header=header, index=index, **self.to_csv_kwargs)
         return written

     def _batch_csv(self, args):
-        offset, header, to_csv_kwargs = args
+        offset, header, index, to_csv_kwargs = args

         batch = query_table(
             table=self.dataset.data,
             key=slice(offset, offset + self.batch_size),
             indices=self.dataset._indices,
         )
         csv_str = batch.to_pandas().to_csv(
-            path_or_buf=None, header=header if (offset == 0) else False, **to_csv_kwargs
+            path_or_buf=None, header=header if (offset == 0) else False, index=index, **to_csv_kwargs
         )
         return csv_str.encode(self.encoding)

-    def _write(self, file_obj: BinaryIO, header: bool = True, **to_csv_kwargs) -> int:
+    def _write(self, file_obj: BinaryIO, header, index, **to_csv_kwargs) -> int:
         """Writes the pyarrow table as CSV to a binary file handle.

         Caller is responsible for opening and closing the handle.
@@ -122,7 +123,7 @@ def _write(self, file_obj: BinaryIO, header: bool = True, **to_csv_kwargs) -> int:
                 disable=not logging.is_progress_bar_enabled(),
                 desc="Creating CSV from Arrow format",
             ):
-                csv_str = self._batch_csv((offset, header, to_csv_kwargs))
+                csv_str = self._batch_csv((offset, header, index, to_csv_kwargs))
                 written += file_obj.write(csv_str)

         else:
@@ -131,7 +132,7 @@ def _write(self, file_obj: BinaryIO, header: bool = True, **to_csv_kwargs) -> int:
                 for csv_str in logging.tqdm(
                     pool.imap(
                         self._batch_csv,
-                        [(offset, header, to_csv_kwargs) for offset in range(0, num_rows, batch_size)],
+                        [(offset, header, index, to_csv_kwargs) for offset in range(0, num_rows, batch_size)],
                     ),
                     total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
                     unit="ba",
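For orientation, this is roughly the per-batch pattern the CSV writer now follows: `write()` pops `header` and `index` once and threads them through every batch, so only the first chunk can emit the header row while the index setting applies uniformly. The snippet below is a standalone sketch with a plain pandas DataFrame standing in for each Arrow batch; the helper name is illustrative only.

    import pandas as pd

    def batch_to_csv_str(df: pd.DataFrame, offset: int, header: bool, index: bool) -> str:
        # Only the first batch (offset == 0) may emit the header row;
        # the index flag is applied to every batch.
        return df.to_csv(path_or_buf=None, header=header if offset == 0 else False, index=index)

    df = pd.DataFrame({"a": [1, 2, 3]})
    print(batch_to_csv_str(df, offset=0, header=True, index=False))  # "a\n1\n2\n3\n"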

src/datasets/io/json.py

Lines changed: 13 additions & 7 deletions
@@ -92,33 +92,38 @@ def __init__(
     def write(self) -> int:
         _ = self.to_json_kwargs.pop("path_or_buf", None)
         orient = self.to_json_kwargs.pop("orient", "records")
-        lines = self.to_json_kwargs.pop("lines", True)
+        lines = self.to_json_kwargs.pop("lines", True if orient == "records" else False)
+        index = self.to_json_kwargs.pop("index", False if orient in ["split", "table"] else True)
         compression = self.to_json_kwargs.pop("compression", None)

         if compression not in [None, "infer", "gzip", "bz2", "xz"]:
             raise NotImplementedError(f"`datasets` currently does not support {compression} compression")

         if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
             with fsspec.open(self.path_or_buf, "wb", compression=compression) as buffer:
-                written = self._write(file_obj=buffer, orient=orient, lines=lines, **self.to_json_kwargs)
+                written = self._write(file_obj=buffer, orient=orient, lines=lines, index=index, **self.to_json_kwargs)
         else:
             if compression:
                 raise NotImplementedError(
                     f"The compression parameter is not supported when writing to a buffer, but compression={compression}"
                     " was passed. Please provide a local path instead."
                 )
-            written = self._write(file_obj=self.path_or_buf, orient=orient, lines=lines, **self.to_json_kwargs)
+            written = self._write(
+                file_obj=self.path_or_buf, orient=orient, lines=lines, index=index, **self.to_json_kwargs
+            )
         return written

     def _batch_json(self, args):
-        offset, orient, lines, to_json_kwargs = args
+        offset, orient, lines, index, to_json_kwargs = args

         batch = query_table(
             table=self.dataset.data,
             key=slice(offset, offset + self.batch_size),
             indices=self.dataset._indices,
         )
-        json_str = batch.to_pandas().to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)
+        json_str = batch.to_pandas().to_json(
+            path_or_buf=None, orient=orient, lines=lines, index=index, **to_json_kwargs
+        )
         if not json_str.endswith("\n"):
             json_str += "\n"
         return json_str.encode(self.encoding)
@@ -128,6 +133,7 @@ def _write(
         file_obj: BinaryIO,
         orient,
         lines,
+        index,
         **to_json_kwargs,
     ) -> int:
         """Writes the pyarrow table as JSON lines to a binary file handle.
@@ -143,15 +149,15 @@ def _write(
                 disable=not logging.is_progress_bar_enabled(),
                 desc="Creating json from Arrow format",
             ):
-                json_str = self._batch_json((offset, orient, lines, to_json_kwargs))
+                json_str = self._batch_json((offset, orient, lines, index, to_json_kwargs))
                 written += file_obj.write(json_str)
         else:
             num_rows, batch_size = len(self.dataset), self.batch_size
             with multiprocessing.Pool(self.num_proc) as pool:
                 for json_str in logging.tqdm(
                     pool.imap(
                         self._batch_json,
-                        [(offset, orient, lines, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
+                        [(offset, orient, lines, index, to_json_kwargs) for offset in range(0, num_rows, batch_size)],
                     ),
                     total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
                     unit="ba",
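The new defaults here are derived from `orient`: JSON Lines output is only meaningful for `orient="records"`, and pandas accepts `index=False` only for the `"split"` and `"table"` orients, so those are the only orients that drop the index by default. A small standalone sketch of that defaulting logic (the helper name is illustrative, not part of the codebase):

    def json_export_defaults(to_json_kwargs: dict) -> dict:
        # Mirror of the defaulting added in JsonDatasetWriter.write().
        kwargs = dict(to_json_kwargs)
        orient = kwargs.setdefault("orient", "records")
        kwargs.setdefault("lines", orient == "records")  # JSON Lines only for "records"
        kwargs.setdefault("index", orient not in ("split", "table"))  # drop index where pandas allows it
        return kwargs

    print(json_export_defaults({}))                   # {'orient': 'records', 'lines': True, 'index': True}
    print(json_export_defaults({"orient": "split"}))  # {'orient': 'split', 'lines': False, 'index': False}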

src/datasets/io/sql.py

Lines changed: 7 additions & 6 deletions
@@ -77,23 +77,24 @@ def __init__(
     def write(self) -> int:
         _ = self.to_sql_kwargs.pop("sql", None)
         _ = self.to_sql_kwargs.pop("con", None)
+        index = self.to_sql_kwargs.pop("index", False)

-        written = self._write(**self.to_sql_kwargs)
+        written = self._write(index=index, **self.to_sql_kwargs)
         return written

     def _batch_sql(self, args):
-        offset, to_sql_kwargs = args
+        offset, index, to_sql_kwargs = args
         to_sql_kwargs = {**to_sql_kwargs, "if_exists": "append"} if offset > 0 else to_sql_kwargs
         batch = query_table(
             table=self.dataset.data,
             key=slice(offset, offset + self.batch_size),
             indices=self.dataset._indices,
         )
         df = batch.to_pandas()
-        num_rows = df.to_sql(self.name, self.con, **to_sql_kwargs)
+        num_rows = df.to_sql(self.name, self.con, index=index, **to_sql_kwargs)
         return num_rows or len(df)

-    def _write(self, **to_sql_kwargs) -> int:
+    def _write(self, index, **to_sql_kwargs) -> int:
         """Writes the pyarrow table as SQL to a database.

         Caller is responsible for opening and closing the SQL connection.
@@ -107,14 +108,14 @@ def _write(self, **to_sql_kwargs) -> int:
                 disable=not logging.is_progress_bar_enabled(),
                 desc="Creating SQL from Arrow format",
             ):
-                written += self._batch_sql((offset, to_sql_kwargs))
+                written += self._batch_sql((offset, index, to_sql_kwargs))
         else:
             num_rows, batch_size = len(self.dataset), self.batch_size
             with multiprocessing.Pool(self.num_proc) as pool:
                 for num_rows in logging.tqdm(
                     pool.imap(
                         self._batch_sql,
-                        [(offset, to_sql_kwargs) for offset in range(0, num_rows, batch_size)],
+                        [(offset, index, to_sql_kwargs) for offset in range(0, num_rows, batch_size)],
                     ),
                     total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
                     unit="ba",
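The SQL writer keeps the same batching scheme: the first batch uses the caller's `if_exists` value and later batches append, with a single `index` flag shared by all batches. A rough standalone sketch of that pattern, using an in-memory SQLite connection as a stand-in for the real target database:

    import sqlite3

    import pandas as pd

    con = sqlite3.connect(":memory:")
    batches = [pd.DataFrame({"a": [1, 2]}), pd.DataFrame({"a": [3, 4]})]

    for batch_num, df in enumerate(batches):
        # First batch uses the caller's if_exists (here the pandas default "fail",
        # which creates the table); later batches always append.
        extra = {"if_exists": "append"} if batch_num > 0 else {}
        df.to_sql("data", con, index=False, **extra)  # index=False is the new default

    print(pd.read_sql("SELECT * FROM data", con))  # four rows, no index column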

tests/io/test_json.py

Lines changed: 2 additions & 2 deletions
@@ -188,7 +188,7 @@ def test_dataset_to_json_lines(self, lines, load_json_function, dataset):
         "orient, container, keys, len_at",
         [
             ("records", list, {"tokens", "labels", "answers", "id"}, None),
-            ("split", dict, {"index", "columns", "data"}, "data"),
+            ("split", dict, {"columns", "data"}, "data"),
             ("index", dict, set("0123456789"), None),
             ("columns", dict, {"tokens", "labels", "answers", "id"}, "tokens"),
             ("values", list, None, None),
@@ -227,7 +227,7 @@ def test_dataset_to_json_lines_multiproc(self, lines, load_json_function, dataset):
         "orient, container, keys, len_at",
         [
             ("records", list, {"tokens", "labels", "answers", "id"}, None),
-            ("split", dict, {"index", "columns", "data"}, "data"),
+            ("split", dict, {"columns", "data"}, "data"),
             ("index", dict, set("0123456789"), None),
             ("columns", dict, {"tokens", "labels", "answers", "id"}, "tokens"),
             ("values", list, None, None),

tests/io/test_sql.py

Lines changed: 3 additions & 3 deletions
@@ -66,7 +66,7 @@ def test_dataset_to_sql(sqlite_path, tmp_path):
     cache_dir = tmp_path / "cache"
     output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
     dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
-    SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, index=False, num_proc=1).write()
+    SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, num_proc=1).write()

     original_sql = iter_sql_file(sqlite_path)
     expected_sql = iter_sql_file(output_sqlite_path)
@@ -80,7 +80,7 @@ def test_dataset_to_sql_multiproc(sqlite_path, tmp_path):
     cache_dir = tmp_path / "cache"
     output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
     dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
-    SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, index=False, num_proc=2).write()
+    SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, num_proc=2).write()

     original_sql = iter_sql_file(sqlite_path)
     expected_sql = iter_sql_file(output_sqlite_path)
@@ -95,4 +95,4 @@ def test_dataset_to_sql_invalidproc(sqlite_path, tmp_path):
     output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
     dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
     with pytest.raises(ValueError):
-        SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, index=False, num_proc=0).write()
+        SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, num_proc=0).write()

tests/test_arrow_dataset.py

Lines changed: 5 additions & 5 deletions
@@ -2259,7 +2259,7 @@ def test_to_sql(self, in_memory):
            # Destination specified as database URI string
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
-                _ = dset.to_sql("data", "sqlite:///" + file_path, index=False)
+                _ = dset.to_sql("data", "sqlite:///" + file_path)

                self.assertTrue(os.path.isfile(file_path))
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
@@ -2273,7 +2273,7 @@ def test_to_sql(self, in_memory):

                file_path = os.path.join(tmp_dir, "test_path.sqlite")
                with contextlib.closing(sqlite3.connect(file_path)) as con:
-                    _ = dset.to_sql("data", con, index=False, if_exists="replace")
+                    _ = dset.to_sql("data", con, if_exists="replace")

                self.assertTrue(os.path.isfile(file_path))
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
@@ -2284,7 +2284,7 @@ def test_to_sql(self, in_memory):
            # Test writing to a database in chunks
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
-                _ = dset.to_sql("data", "sqlite:///" + file_path, batch_size=1, index=False, if_exists="replace")
+                _ = dset.to_sql("data", "sqlite:///" + file_path, batch_size=1, if_exists="replace")

                self.assertTrue(os.path.isfile(file_path))
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
@@ -2296,7 +2296,7 @@ def test_to_sql(self, in_memory):
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                dset = dset.select(range(0, len(dset), 2)).shuffle()
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
-                _ = dset.to_sql("data", "sqlite:///" + file_path, index=False, if_exists="replace")
+                _ = dset.to_sql("data", "sqlite:///" + file_path, if_exists="replace")

                self.assertTrue(os.path.isfile(file_path))
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
@@ -2307,7 +2307,7 @@ def test_to_sql(self, in_memory):
            # With array features
            with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
                file_path = os.path.join(tmp_dir, "test_path.sqlite")
-                _ = dset.to_sql("data", "sqlite:///" + file_path, index=False, if_exists="replace")
+                _ = dset.to_sql("data", "sqlite:///" + file_path, if_exists="replace")

                self.assertTrue(os.path.isfile(file_path))
                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
