23 changes: 15 additions & 8 deletions src/datasets/packaged_modules/json/json.py
@@ -95,6 +95,7 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
             pa_table = table_cast(pa_table, self.config.features.arrow_schema)
         return pa_table
 
+
     def _generate_tables(self, files):
         for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
             # If the file is one json object and if we need to look at the items in one specific field
@@ -113,8 +114,6 @@ def _generate_tables(self, files):
             else:
                 with open(file, "rb") as f:
                     batch_idx = 0
-                    # Use block_size equal to the chunk size divided by 32 to leverage multithreading
-                    # Set a default minimum value of 16kB if the chunk size is really small
                     block_size = max(self.config.chunksize // 32, 16 << 10)
                     encoding_errors = (
                         self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
@@ -123,19 +122,18 @@ def _generate_tables(self, files):
                         batch = f.read(self.config.chunksize)
                         if not batch:
                             break
-                        # Finish current line
                        try:
                             batch += f.readline()
                         except (AttributeError, io.UnsupportedOperation):
                             batch += readline(f)
-                        # PyArrow only accepts utf-8 encoded bytes
                         if self.config.encoding != "utf-8":
                             batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
                         try:
                             while True:
                                 try:
                                     pa_table = paj.read_json(
-                                        io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
+                                        io.BytesIO(batch),
+                                        read_options=paj.ReadOptions(block_size=block_size),
                                     )
                                     break
                                 except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
@@ -146,23 +144,31 @@ def _generate_tables(self, files):
                                     ):
                                         raise
                                     else:
-                                        # Increase the block size in case it was too small.
-                                        # The block size will be reset for the next file.
                                         logger.debug(
                                             f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
                                         )
                                         block_size *= 2
                         except pa.ArrowInvalid as e:
                             try:
                                 with open(
-                                    file, encoding=self.config.encoding, errors=self.config.encoding_errors
+                                    file,
+                                    encoding=self.config.encoding,
+                                    errors=self.config.encoding_errors,
                                 ) as f:
                                     df = pandas_read_json(f)
                             except ValueError:
                                 logger.error(f"Failed to load JSON from file '{file}' with error {type(e)}: {e}")
                                 raise e
                             if df.columns.tolist() == [0]:
                                 df.columns = list(self.config.features) if self.config.features else ["text"]
+
+                            # ✅ FIX: Coerce float-looking ints (like 0.0, 1.0) back to float64
+                            for col in df.columns:
+                                col_data = df[col].dropna()
+                                if col_data.apply(lambda x: isinstance(x, float)).all():
+                                    if col_data.apply(lambda x: x.is_integer()).all():
+                                        df[col] = df[col].astype("float64")
+
                             try:
                                 pa_table = pa.Table.from_pandas(df, preserve_index=False)
                             except pa.ArrowInvalid as e:
@@ -176,3 +182,4 @@ def _generate_tables(self, files):
                             break
                         yield (file_idx, batch_idx), self._cast_table(pa_table)
                         batch_idx += 1
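
A note on the retry loop the diff reformats (illustration only, not part of the PR): `pyarrow.json.read_json` parses its input in blocks of `ReadOptions.block_size` bytes and raises `ArrowInvalid` with a "straddling" message when a single JSON document crosses a block boundary, which is why the loader doubles `block_size` and retries until the batch parses or the block outgrows the batch. A minimal sketch of that behavior, using a made-up 4 kB record and a deliberately small starting block size:

```python
import io

import pyarrow as pa
import pyarrow.json as paj

# One JSON Lines record, deliberately larger than the initial block size.
line = b'{"text": "' + b"x" * 4096 + b'"}\n'

block_size = 1024  # too small: the record straddles block boundaries
while True:
    try:
        table = paj.read_json(
            io.BytesIO(line),
            read_options=paj.ReadOptions(block_size=block_size),
        )
        break
    except pa.ArrowInvalid:
        # Same recovery as the loader: double the block size and retry.
        block_size *= 2

print(table.num_rows, block_size)  # 1 row once block_size covers the record
```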
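On the new coercion pass itself: the pandas fallback (`pandas_read_json`) runs pandas' dtype inference, which on recent pandas versions downcasts a float column whose values are all integral (0.0, 1.0, …) to int64, silently changing the schema of the resulting Arrow table. A hedged sketch of the symptom, with a made-up `score` column; the `dtype=False` variant is shown only as a contrast, not as what the loader does:

```python
import io

import pandas as pd

payload = '[{"score": 0.0}, {"score": 1.0}]'

# Default type inference downcasts the all-integral float column to int64.
df = pd.read_json(io.StringIO(payload))
print(df["score"].dtype)  # int64

# Disabling inference keeps the float64 the JSON actually contained.
df = pd.read_json(io.StringIO(payload), dtype=False)
print(df["score"].dtype)  # float64
```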