100 changes: 57 additions & 43 deletions src/datasets/packaged_modules/json/json.py
@@ -50,6 +50,7 @@ class JsonConfig(datasets.BuilderConfig):
block_size: Optional[int] = None # deprecated
chunksize: int = 10 << 20 # 10MB
newlines_in_values: Optional[bool] = None
columns: Optional[List[str]] = None

def __post_init__(self):
super().__post_init__()
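
The new columns option lets load_dataset("json", ...) project onto a subset of columns at read time. A minimal usage sketch, assuming this PR's behavior; the file name is a placeholder:

from datasets import load_dataset

# data.jsonl (placeholder): each line is {"a": ..., "b": ..., "c": ...}
ds = load_dataset("json", data_files="data.jsonl", columns=["a", "c"])
print(ds["train"].column_names)  # ['a', 'c']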
@@ -121,6 +122,14 @@ def _generate_tables(self, files):
if df.columns.tolist() == [0]:
df.columns = list(self.config.features) if self.config.features else ["text"]
pa_table = pa.Table.from_pandas(df, preserve_index=False)

# Filter only selected columns if specified
if self.config.columns is not None:
missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
for col in missing_cols:
pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
pa_table = pa_table.select(self.config.columns)

yield file_idx, self._cast_table(pa_table)

# If the file has one json object per line
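
The filter above null-fills any requested column the file lacks, then projects with Table.select. A standalone sketch of that pattern on a toy table:

import pyarrow as pa

table = pa.table({"a": [1, 2], "b": [3, 4]})
requested = ["a", "c"]  # "c" is absent from the source

# Null-fill the missing columns, then project to the requested schema
missing = [col for col in requested if col not in table.column_names]
for col in missing:
    table = table.append_column(col, pa.array([None] * table.num_rows))
table = table.select(requested)

print(table.column_names)             # ['a', 'c']
print(table.column("c").to_pylist())  # [None, None]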
@@ -137,56 +146,61 @@ def _generate_tables(self, files):
batch = f.read(self.config.chunksize)
if not batch:
break
# Finish current line
try:
batch += f.readline()
except (AttributeError, io.UnsupportedOperation):
batch += readline(f)
# PyArrow only accepts utf-8 encoded bytes
if self.config.encoding != "utf-8":
batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
try:
while True:
try:
pa_table = paj.read_json(
io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
)
break
except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
if (
isinstance(e, pa.ArrowInvalid)
and "straddling" not in str(e)
or block_size > len(batch)
):
raise
else:
# Increase the block size in case it was too small.
# The block size will be reset for the next file.
logger.debug(
f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
)
block_size *= 2
except pa.ArrowInvalid as e:
try:
with open(
file, encoding=self.config.encoding, errors=self.config.encoding_errors
) as f:
df = pandas_read_json(f)
except ValueError:
logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
raise e
if df.columns.tolist() == [0]:
df.columns = list(self.config.features) if self.config.features else ["text"]

while True:
try:
pa_table = pa.Table.from_pandas(df, preserve_index=False)
except pa.ArrowInvalid as e:
logger.error(
f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
pa_table = paj.read_json(
io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
)
raise ValueError(
f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
) from None
yield file_idx, self._cast_table(pa_table)
break
break
except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
if (
isinstance(e, pa.ArrowInvalid)
and "straddling" not in str(e)
) or block_size > len(batch):
raise
logger.debug(
f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. "
f"Retrying with block_size={block_size * 2}."
)
block_size *= 2

if self.config.columns is not None:
missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
for col in missing_cols:
pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
pa_table = pa_table.select(self.config.columns)

yield (file_idx, batch_idx), self._cast_table(pa_table)
batch_idx += 1
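
For context on the retry loop: pyarrow reports a JSON object that crosses a block boundary as an ArrowInvalid "straddling" error, so the reader doubles block_size until the parse succeeds or the block already covers the whole batch. A minimal sketch of that policy (the helper name is ours, not the library's):

import io

import pyarrow as pa
import pyarrow.json as paj

def read_json_batch(batch: bytes, block_size: int) -> pa.Table:
    # Double block_size on "straddling" errors; re-raise anything else,
    # or give up once the block is already larger than the batch.
    while True:
        try:
            return paj.read_json(
                io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
            )
        except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
            if (isinstance(e, pa.ArrowInvalid) and "straddling" not in str(e)) or block_size > len(batch):
                raise
            block_size *= 2

table = read_json_batch(b'{"a": 1}\n{"a": 2}\n', block_size=16)
print(table.num_rows)  # 2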

# Pandas fallback in case of ArrowInvalid
try:
Member:
this code is not at the right location anymore: it should trigger on ArrowInvalid

Contributor Author:
I've moved the Pandas fallback into the except pa.ArrowInvalid block, could you check?

with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
df = pandas_read_json(f)
except ValueError as e:
logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
raise e
if df.columns.tolist() == [0]:
df.columns = list(self.config.features) if self.config.features else ["text"]
try:
pa_table = pa.Table.from_pandas(df, preserve_index=False)
if self.config.columns is not None:
missing_cols = [col for col in self.config.columns if col not in pa_table.column_names]
for col in missing_cols:
pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
pa_table = pa_table.select(self.config.columns)
yield (file_idx, batch_idx), self._cast_table(pa_table)
except pa.ArrowInvalid as e:
logger.error(
f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
)
raise ValueError(
f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
) from None
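
The pandas fallback handles layouts paj.read_json cannot parse, for instance a file whose top level is a JSON array of records rather than JSON Lines. A rough sketch of that path, using pd.read_json directly in place of the loader's pandas_read_json wrapper:

import io

import pandas as pd
import pyarrow as pa

# Top-level JSON array of records: not JSON Lines, so paj.read_json
# would raise ArrowInvalid, but pandas parses it fine.
payload = '[{"a": 1, "b": 2}, {"a": 3, "b": 4}]'

df = pd.read_json(io.StringIO(payload))
pa_table = pa.Table.from_pandas(df, preserve_index=False)
print(pa_table.column_names)  # ['a', 'b']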
19 changes: 18 additions & 1 deletion tests/packaged_modules/test_json.py
@@ -1,8 +1,14 @@
# Standard library
import json
import tempfile
import textwrap

import pyarrow as pa
# Third-party
import pytest
import pyarrow as pa

# First-party (datasets)
from datasets import load_dataset
from datasets import Features, Value
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
@@ -265,3 +271,14 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, request):
generator = builder._generate_tables([[request.getfixturevalue(file_fixture)]])
pa_table = pa.concat_tables([table for _, table in generator])
assert pa_table.column_names == ["ID", "Language", "Topic"]

def test_load_dataset_json_with_columns_filtering():
sample = {"a": 1, "b": 2, "c": 3}

with tempfile.NamedTemporaryFile("w+", suffix=".jsonl", delete=False) as f:
f.write(json.dumps(sample) + "\n")
f.write(json.dumps(sample) + "\n")
path = f.name

dataset = load_dataset("json", data_files=path, columns=["a", "c"])
assert set(dataset["train"].column_names) == {"a", "c"}
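
A hypothetical companion test, not part of this PR, that would pin down the null-filling behavior for a requested column the file lacks:

def test_load_dataset_json_with_missing_requested_column():
    # Hypothetical: a requested column absent from the file should
    # come back as an all-null column of the same length.
    sample = {"a": 1}

    with tempfile.NamedTemporaryFile("w+", suffix=".jsonl", delete=False) as f:
        f.write(json.dumps(sample) + "\n")
        path = f.name

    dataset = load_dataset("json", data_files=path, columns=["a", "d"])
    assert set(dataset["train"].column_names) == {"a", "d"}
    assert dataset["train"]["d"] == [None]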