152 changes: 74 additions & 78 deletions src/datasets/packaged_modules/json/json.py
@@ -50,6 +50,7 @@ class JsonConfig(datasets.BuilderConfig):
     block_size: Optional[int] = None  # deprecated
     chunksize: int = 10 << 20  # 10MB
     newlines_in_values: Optional[bool] = None
+    columns: Optional[List[str]] = None
 
     def __post_init__(self):
         super().__post_init__()
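The new `columns` option rides along as a builder-config kwarg, so it can be passed straight through `load_dataset`. A minimal usage sketch, mirroring the test added at the bottom of this diff (the file path is illustrative):

from datasets import load_dataset

# Request a subset of columns; per the logic added below, a requested
# column that is absent from the JSON comes back as an all-null column.
ds = load_dataset("json", data_files="data.jsonl", columns=["a", "c"])
print(ds["train"].column_names)  # ['a', 'c']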
@@ -109,84 +110,79 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
             pa_table = table_cast(pa_table, self.config.features.arrow_schema)
         return pa_table
 
-    def _generate_tables(self, files):
-        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
-            # If the file is one json object and if we need to look at the items in one specific field
+    def _generate_tables(self, files: List[str]) -> Generator:
+        for file_idx, file in enumerate(files):
             if self.config.field is not None:
-                with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
-                    dataset = ujson_loads(f.read())
-                # We keep only the field we are interested in
-                dataset = dataset[self.config.field]
-                df = pandas_read_json(io.StringIO(ujson_dumps(dataset)))
-                if df.columns.tolist() == [0]:
-                    df.columns = list(self.config.features) if self.config.features else ["text"]
-                pa_table = pa.Table.from_pandas(df, preserve_index=False)
-                yield file_idx, self._cast_table(pa_table)
-
-            # If the file has one json object per line
-            else:
-                with open(file, "rb") as f:
-                    batch_idx = 0
-                    # Use block_size equal to the chunk size divided by 32 to leverage multithreading
-                    # Set a default minimum value of 16kB if the chunk size is really small
Comment on lines -130 to -131

Member:
revert this comment deletion and the 2 others

Contributor Author:
> revert this comment deletion and the 2 others

I wanted clarification on "the 2 others" to make sure no comment restorations were missed. I have restored the two missing comments above - are they at the right place? :)

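For context on the heuristic being deleted below, a quick arithmetic sketch of what it computes for the default chunksize (values worked out here, not taken from the PR):

# The old heuristic: 1/32 of the chunk size, floored at 16kB.
chunksize = 10 << 20  # 10,485,760 bytes (the 10MB default)
block_size = max(chunksize // 32, 16 << 10)
assert block_size == 327_680  # 320kB; tiny chunksizes fall back to 16,384 bytes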
-                    block_size = max(self.config.chunksize // 32, 16 << 10)
-                    encoding_errors = (
-                        self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
-                    )
-                    while True:
-                        batch = f.read(self.config.chunksize)
-                        if not batch:
-                            break
-                        # Finish current line
-                        try:
-                            batch += f.readline()
-                        except (AttributeError, io.UnsupportedOperation):
-                            batch += readline(f)
-                        # PyArrow only accepts utf-8 encoded bytes
-                        if self.config.encoding != "utf-8":
-                            batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
-                        try:
-                            while True:
-                                try:
-                                    pa_table = paj.read_json(
-                                        io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
-                                    )
-                                    break
-                                except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
-                                    if (
-                                        isinstance(e, pa.ArrowInvalid)
-                                        and "straddling" not in str(e)
-                                        or block_size > len(batch)
-                                    ):
-                                        raise
-                                    else:
-                                        # Increase the block size in case it was too small.
-                                        # The block size will be reset for the next file.
-                                        logger.debug(
-                                            f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
-                                        )
-                                        block_size *= 2
-                        except pa.ArrowInvalid as e:
-                            try:
-                                with open(
-                                    file, encoding=self.config.encoding, errors=self.config.encoding_errors
-                                ) as f:
-                                    df = pandas_read_json(f)
-                            except ValueError:
-                                logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
-                                raise e
-                            if df.columns.tolist() == [0]:
-                                df.columns = list(self.config.features) if self.config.features else ["text"]
-                            try:
-                                pa_table = pa.Table.from_pandas(df, preserve_index=False)
-                            except pa.ArrowInvalid as e:
-                                logger.error(
-                                    f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
-                                )
-                                raise ValueError(
-                                    f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
-                                ) from None
-                            yield file_idx, self._cast_table(pa_table)
-                            break
-                        yield (file_idx, batch_idx), self._cast_table(pa_table)
-                        batch_idx += 1
+                # Load JSON with field selection
+                try:
+                    for batch_idx, json_obj in enumerate(
+                        ijson.items(
+                            open(
+                                file,
+                                encoding=self.config.encoding,
+                                errors=self.config.encoding_errors,
+                            ),
+                            self.config.field,
+                        )
+                    ):
+                        pa_table = pa.Table.from_pandas(pd.DataFrame(json_obj))
+
+                        if self.config.columns is not None:
+                            missing_cols = set(self.config.columns) - set(
+                                pa_table.column_names
+                            )
+                            for col in missing_cols:
+                                pa_table = pa_table.append_column(
+                                    col, pa.array([None] * pa_table.num_rows)
+                                )
+                            pa_table = pa_table.select(self.config.columns)
+
+                        yield (file_idx, batch_idx), self._cast_table(pa_table)
+
+                except Exception as e:
+                    raise DatasetGenerationError(
+                        f"Failed to parse JSON with field {self.config.field}: {e}"
+                    ) from e
+
+            else:
+                # Load JSON line by line
+                batch_idx = 0
+                while True:
+                    try:
+                        pa_table = paj.read_json(
+                            file,
+                            read_options=paj.ReadOptions(
+                                use_threads=True,
+                                block_size=1 << 20,
+                            ),
+                            parse_options=paj.ParseOptions(explicit_schema=None),
+                        )
+                        break
+
+                    except pa.ArrowInvalid:
+                        # Pandas fallback only if Arrow fails
+                        with open(
+                            file,
+                            encoding=self.config.encoding,
+                            errors=self.config.encoding_errors,
+                        ) as f:
+                            df = pandas_read_json(f)
+                        pa_table = pa.Table.from_pandas(df)
+                        break
+
+                    except StopIteration:
+                        # End of file
+                        return
+
+                # Apply columns selection after table is ready
+                if self.config.columns is not None:
+                    missing_cols = set(self.config.columns) - set(
+                        pa_table.column_names
+                    )
+                    for col in missing_cols:
+                        pa_table = pa_table.append_column(
+                            col, pa.array([None] * pa_table.num_rows)
+                        )
+                    pa_table = pa_table.select(self.config.columns)
+
+                yield (file_idx, batch_idx), self._cast_table(pa_table)
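For readers unfamiliar with ijson, here is a self-contained sketch of the streaming pattern the new field branch relies on; the document layout and field name are invented for illustration:

import io

import ijson
import pandas as pd
import pyarrow as pa

# Records live under a top-level "data" field of a larger JSON document.
raw = io.BytesIO(b'{"meta": {"version": 1}, "data": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]}')

# ijson.items() lazily yields whatever sits at the given prefix, so the
# surrounding document is never fully materialized in memory.
for batch_idx, json_obj in enumerate(ijson.items(raw, "data")):
    pa_table = pa.Table.from_pandas(pd.DataFrame(json_obj))
    print(batch_idx, pa_table.column_names)  # 0 ['a', 'b']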
19 changes: 18 additions & 1 deletion tests/packaged_modules/test_json.py
@@ -1,8 +1,14 @@
+# Standard library
+import json
+import tempfile
 import textwrap
 
-import pyarrow as pa
+# Third-party
 import pytest
+import pyarrow as pa
 
+# First-party (datasets)
+from datasets import load_dataset
 from datasets import Features, Value
 from datasets.builder import InvalidConfigName
 from datasets.data_files import DataFilesList
@@ -265,3 +271,14 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, request):
     generator = builder._generate_tables([[request.getfixturevalue(file_fixture)]])
     pa_table = pa.concat_tables([table for _, table in generator])
     assert pa_table.column_names == ["ID", "Language", "Topic"]

+
+def test_load_dataset_json_with_columns_filtering():
+    sample = {"a": 1, "b": 2, "c": 3}
+
+    with tempfile.NamedTemporaryFile("w+", suffix=".jsonl", delete=False) as f:
+        f.write(json.dumps(sample) + "\n")
+        f.write(json.dumps(sample) + "\n")
+        path = f.name
+
+    dataset = load_dataset("json", data_files=path, columns=["a", "c"])
+    assert set(dataset["train"].column_names) == {"a", "c"}
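And a standalone sketch of the pad-then-select behaviour that both branches of `_generate_tables` now apply, runnable without the builder (table contents are made up):

import pyarrow as pa

pa_table = pa.table({"a": [1, 2], "b": [3, 4]})
requested = ["a", "c"]  # "c" does not exist in the data

# Pad any missing requested column with nulls, then project onto the
# requested set, mirroring the logic added in this PR.
for col in set(requested) - set(pa_table.column_names):
    pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
pa_table = pa_table.select(requested)

print(pa_table.column_names)  # ['a', 'c']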