Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,7 @@ def load_dataset(
streaming: bool = False,
num_proc: Optional[int] = None,
storage_options: Optional[dict] = None,
columns: Optional[List[str]] = None,
**config_kwargs,
) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]:
"""Load a dataset from the Hugging Face Hub, or a local dataset.
Expand Down Expand Up @@ -1388,6 +1389,8 @@ def load_dataset(
(verification_mode or VerificationMode.BASIC_CHECKS) if not save_infos else VerificationMode.ALL_CHECKS
)

if path == "json" and columns is not None:
config_kwargs["columns"] = columns
# Create a dataset builder
builder_instance = load_dataset_builder(
path=path,
Expand Down
29 changes: 21 additions & 8 deletions src/datasets/packaged_modules/json/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class JsonConfig(datasets.BuilderConfig):
block_size: Optional[int] = None # deprecated
chunksize: int = 10 << 20 # 10MB
newlines_in_values: Optional[bool] = None
columns: Optional[List[str]] = None

def __post_init__(self):
super().__post_init__()
Expand Down Expand Up @@ -107,14 +108,20 @@ def _generate_tables(self, files):
if df.columns.tolist() == [0]:
df.columns = list(self.config.features) if self.config.features else ["text"]
pa_table = pa.Table.from_pandas(df, preserve_index=False)

# Filter only selected columns if specified
if self.config.columns is not None:
missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
for col in missing_cols:
pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
pa_table = pa_table.select(self.config.columns)

yield file_idx, self._cast_table(pa_table)

# If the file has one json object per line
else:
with open(file, "rb") as f:
batch_idx = 0
# Use block_size equal to the chunk size divided by 32 to leverage multithreading
# Set a default minimum value of 16kB if the chunk size is really small
Comment on lines -130 to -131
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert this comment deletion and the 2 others

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert this comment deletion and the 2 others

Wanted clarification on “the 2 others” to ensure no comment restorations were missed. Actually i have restored the two missing comments above - are they at the right place? :)

block_size = max(self.config.chunksize // 32, 16 << 10)
encoding_errors = (
self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
Expand All @@ -123,12 +130,10 @@ def _generate_tables(self, files):
batch = f.read(self.config.chunksize)
if not batch:
break
# Finish current line
try:
batch += f.readline()
except (AttributeError, io.UnsupportedOperation):
batch += readline(f)
# PyArrow only accepts utf-8 encoded bytes
if self.config.encoding != "utf-8":
batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
try:
Expand All @@ -137,6 +142,12 @@ def _generate_tables(self, files):
pa_table = paj.read_json(
io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
)
if self.config.columns is not None:
missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
for col in missing_cols:
pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
pa_table = pa_table.select(self.config.columns)
yield (file_idx, batch_idx), self._cast_table(pa_table)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep this at the end, where you removed the yield - this way the try/except is only about the paj.read_json call

break
except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
if (
Expand All @@ -146,8 +157,6 @@ def _generate_tables(self, files):
):
raise
else:
# Increase the block size in case it was too small.
# The block size will be reset for the next file.
logger.debug(
f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
)
Expand All @@ -165,14 +174,18 @@ def _generate_tables(self, files):
df.columns = list(self.config.features) if self.config.features else ["text"]
try:
pa_table = pa.Table.from_pandas(df, preserve_index=False)
if self.config.columns is not None:
missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
for col in missing_cols:
pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
pa_table = pa_table.select(self.config.columns)
yield (file_idx, batch_idx), self._cast_table(pa_table)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

except pa.ArrowInvalid as e:
logger.error(
f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
)
raise ValueError(
f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
) from None
yield file_idx, self._cast_table(pa_table)
break
yield (file_idx, batch_idx), self._cast_table(pa_table)
batch_idx += 1
19 changes: 18 additions & 1 deletion tests/packaged_modules/test_json.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
# Standard library
import json
import tempfile
import textwrap

import pyarrow as pa
# Third-party
import pytest
import pyarrow as pa

# First-party (datasets)
from datasets import load_dataset
from datasets import Features, Value
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
Expand Down Expand Up @@ -265,3 +271,14 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, r
generator = builder._generate_tables([[request.getfixturevalue(file_fixture)]])
pa_table = pa.concat_tables([table for _, table in generator])
assert pa_table.column_names == ["ID", "Language", "Topic"]

def test_load_dataset_json_with_columns_filtering():
sample = {"a": 1, "b": 2, "c": 3}

with tempfile.NamedTemporaryFile("w+", suffix=".jsonl", delete=False) as f:
f.write(json.dumps(sample) + "\n")
f.write(json.dumps(sample) + "\n")
path = f.name

dataset = load_dataset("json", data_files=path, columns=["a", "c"])
assert set(dataset["train"].column_names) == {"a", "c"}