diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index c5d8bcd03fc..f9b96fd74cd 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -50,6 +50,7 @@ class JsonConfig(datasets.BuilderConfig):
     block_size: Optional[int] = None  # deprecated
     chunksize: int = 10 << 20  # 10MB
     newlines_in_values: Optional[bool] = None
+    columns: Optional[List[str]] = None

     def __post_init__(self):
         super().__post_init__()
@@ -109,84 +110,44 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
             pa_table = table_cast(pa_table, self.config.features.arrow_schema)
         return pa_table

     def _generate_tables(self, files):
         for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
             # If the file is one json object and if we need to look at the items in one specific field
             if self.config.field is not None:
-                with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
-                    dataset = ujson_loads(f.read())
-                # We keep only the field we are interested in
-                dataset = dataset[self.config.field]
-                df = pandas_read_json(io.StringIO(ujson_dumps(dataset)))
-                if df.columns.tolist() == [0]:
-                    df.columns = list(self.config.features) if self.config.features else ["text"]
-                pa_table = pa.Table.from_pandas(df, preserve_index=False)
-                yield file_idx, self._cast_table(pa_table)
-
-            # If the file has one json object per line
-            else:
-                with open(file, "rb") as f:
-                    batch_idx = 0
-                    # Use block_size equal to the chunk size divided by 32 to leverage multithreading
-                    # Set a default minimum value of 16kB if the chunk size is really small
-                    block_size = max(self.config.chunksize // 32, 16 << 10)
-                    encoding_errors = (
-                        self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
-                    )
-                    while True:
-                        batch = f.read(self.config.chunksize)
-                        if not batch:
-                            break
-                        # Finish current line
-                        try:
-                            batch += f.readline()
-                        except (AttributeError, io.UnsupportedOperation):
-                            batch += readline(f)
-                        # PyArrow only accepts utf-8 encoded bytes
-                        if self.config.encoding != "utf-8":
-                            batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
-                        try:
-                            while True:
-                                try:
-                                    pa_table = paj.read_json(
-                                        io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
-                                    )
-                                    break
-                                except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
-                                    if (
-                                        isinstance(e, pa.ArrowInvalid)
-                                        and "straddling" not in str(e)
-                                        or block_size > len(batch)
-                                    ):
-                                        raise
-                                    else:
-                                        # Increase the block size in case it was too small.
-                                        # The block size will be reset for the next file.
-                                        logger.debug(
-                                            f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
-                                        )
-                                        block_size *= 2
-                        except pa.ArrowInvalid as e:
-                            try:
-                                with open(
-                                    file, encoding=self.config.encoding, errors=self.config.encoding_errors
-                                ) as f:
-                                    df = pandas_read_json(f)
-                            except ValueError:
-                                logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
-                                raise e
-                            if df.columns.tolist() == [0]:
-                                df.columns = list(self.config.features) if self.config.features else ["text"]
-                            try:
-                                pa_table = pa.Table.from_pandas(df, preserve_index=False)
-                            except pa.ArrowInvalid as e:
-                                logger.error(
-                                    f"Failed to convert pandas DataFrame to Arrow Table from file '{file}' with error {type(e)}: {e}"
-                                )
-                                raise ValueError(
-                                    f"Failed to convert pandas DataFrame to Arrow Table from file {file}."
-                                ) from None
-                            yield file_idx, self._cast_table(pa_table)
-                            break
-                        yield (file_idx, batch_idx), self._cast_table(pa_table)
-                        batch_idx += 1
+                # Load the full JSON document and keep only the requested field
+                with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
+                    dataset = ujson_loads(f.read())
+                dataset = dataset[self.config.field]
+                df = pandas_read_json(io.StringIO(ujson_dumps(dataset)))
+                if df.columns.tolist() == [0]:
+                    df.columns = list(self.config.features) if self.config.features else ["text"]
+                pa_table = pa.Table.from_pandas(df, preserve_index=False)
+
+            # Otherwise expect JSON Lines: one JSON object per line
+            else:
+                try:
+                    pa_table = paj.read_json(
+                        file,
+                        read_options=paj.ReadOptions(
+                            use_threads=True,
+                            # Same heuristic as before: chunksize // 32, with a 16kB minimum
+                            block_size=max(self.config.chunksize // 32, 16 << 10),
+                        ),
+                    )
+                except pa.ArrowInvalid as e:
+                    # Fall back to pandas when pyarrow cannot parse the file (e.g. a top-level JSON list)
+                    logger.debug(f"Failed to read '{file}' with pyarrow, falling back to pandas: {e}")
+                    with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
+                        df = pandas_read_json(f)
+                    if df.columns.tolist() == [0]:
+                        df.columns = list(self.config.features) if self.config.features else ["text"]
+                    pa_table = pa.Table.from_pandas(df, preserve_index=False)
+
+            # Column selection: add missing requested columns as all-null columns, then reorder
+            if self.config.columns is not None:
+                missing_cols = set(self.config.columns) - set(pa_table.column_names)
+                for col in missing_cols:
+                    pa_table = pa_table.append_column(col, pa.array([None] * pa_table.num_rows))
+                pa_table = pa_table.select(self.config.columns)
+
+            yield file_idx, self._cast_table(pa_table)
diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py
index 18f066b5e68..47efe35fa63 100644
--- a/tests/packaged_modules/test_json.py
+++ b/tests/packaged_modules/test_json.py
@@ -1,8 +1,10 @@
+import json
+import tempfile
 import textwrap

 import pyarrow as pa
 import pytest

-from datasets import Features, Value
+from datasets import Features, Value, load_dataset
 from datasets.builder import InvalidConfigName
 from datasets.data_files import DataFilesList
@@ -265,3 +267,15 @@ def test_json_generate_tables_with_sorted_columns(file_fixture, config_kwargs, r
     generator = builder._generate_tables([[request.getfixturevalue(file_fixture)]])
     pa_table = pa.concat_tables([table for _, table in generator])
     assert pa_table.column_names == ["ID", "Language", "Topic"]
+
+
+def test_load_dataset_json_with_columns_filtering():
+    sample = {"a": 1, "b": 2, "c": 3}
+
+    with tempfile.NamedTemporaryFile("w+", suffix=".jsonl", delete=False) as f:
+        f.write(json.dumps(sample) + "\n")
"\n") + f.write(json.dumps(sample) + "\n") + path = f.name + + dataset = load_dataset("json", data_files=path, columns=["a", "c"]) + assert set(dataset["train"].column_names) == {"a", "c"}