
Commit 1ddb40e

Adds parquet writer (huggingface#103)
* added parquet writer
* nit
* Update src/datatrove/pipeline/writers/parquet.py
  Co-authored-by: Mario Šaško <mariosasko777@gmail.com>
* updated test
* nit

---------

Co-authored-by: Mario Šaško <mariosasko777@gmail.com>
1 parent b517728 commit 1ddb40e

6 files changed: +98 -8 lines changed


src/datatrove/io.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def __init__(self, fs, mode: str = "wt", compression: str | None = "infer"):
 
     def get_file(self, filename):
         """
-        Opens file `filename` if it hasn't been opened yet. Otherwise just returns it from the file cache
+        Opens file `filename` if it hasn't been opened yet. Otherwise, just returns it from the file cache
         Args:
             filename: name of the file to open/get if previously opened
 
src/datatrove/pipeline/writers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1 +1,2 @@
 from .jsonl import JsonlWriter
+from .parquet import ParquetWriter

src/datatrove/pipeline/writers/disk_base.py

Lines changed: 6 additions & 5 deletions
@@ -1,7 +1,7 @@
 import dataclasses
 from abc import ABC, abstractmethod
 from string import Template
-from typing import Callable
+from typing import IO, Callable
 
 from datatrove.data import Document, DocumentsPipeline
 from datatrove.io import DataFolderLike, get_datafolder
@@ -31,6 +31,7 @@ def __init__(
         output_filename: str = None,
         compression: str | None = "infer",
         adapter: Callable = None,
+        mode: str = "wt",
     ):
         """
         Base writer block to save data to disk.
@@ -47,7 +48,7 @@ def __init__(
         if self.compression == "gzip" and not output_filename.endswith(".gz"):
             output_filename += ".gz"
         self.output_filename = Template(output_filename)
-        self.output_mg = self.output_folder.get_output_file_manager(mode="wt", compression=compression)
+        self.output_mg = self.output_folder.get_output_file_manager(mode=mode, compression=compression)
         self.adapter = adapter if adapter else _default_adapter
 
     def __enter__(self):
@@ -81,13 +82,13 @@ def _get_output_filename(self, document: Document, rank: int | str = 0, **kwargs):
         )
 
     @abstractmethod
-    def _write(self, document: dict, file_handler):
+    def _write(self, document: dict, file_handler: IO, filename: str):
         """
         Main method that subclasses should implement. Receives an adapted (after applying self.adapter) dictionary with data to save to `file_handler`
         Args:
             document: dictionary with the data to save
             file_handler: file_handler where it should be saved
-
+            filename: to use as a key for writer helpers and other data
         Returns:
 
         """
@@ -105,7 +106,7 @@ def write(self, document: Document, rank: int = 0, **kwargs):
 
         """
        output_filename = self._get_output_filename(document, rank, **kwargs)
-        self._write(self.adapter(document), self.output_mg.get_file(output_filename))
+        self._write(self.adapter(document), self.output_mg.get_file(output_filename), output_filename)
         self.stat_update(self._get_output_filename(document, "XXXXX", **kwargs))
         self.stat_update(StatHints.total)
         self.update_doc_stats(document)
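With this change, `_write` also receives the output filename, so writer subclasses that keep per-file state (helper writer objects, buffers, headers) can use it as a key, and binary formats can request `mode="wb"` from the base class. A minimal sketch of how a hypothetical subclass might use both additions (this `TSVWriter` is illustrative only and not part of this commit):

from typing import IO, Callable

from datatrove.io import DataFolderLike
from datatrove.pipeline.writers.disk_base import DiskWriter


class TSVWriter(DiskWriter):
    # hypothetical example writer, not part of this commit
    default_output_filename: str = "${rank}.tsv"
    name = "TSV example"

    def __init__(self, output_folder: DataFolderLike, output_filename: str = None, adapter: Callable = None):
        # text format, so the default mode="wt" is fine; a binary format would pass mode="wb"
        super().__init__(output_folder, output_filename=output_filename, adapter=adapter, mode="wt")
        self._header_written = set()  # per-file state, keyed by the new `filename` argument

    def _write(self, document: dict, file_handler: IO, filename: str):
        # write a header row the first time each output file is seen
        if filename not in self._header_written:
            file_handler.write("\t".join(document.keys()) + "\n")
            self._header_written.add(filename)
        file_handler.write("\t".join(str(v) for v in document.values()) + "\n")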

src/datatrove/pipeline/writers/jsonl.py

Lines changed: 2 additions & 2 deletions
@@ -18,5 +18,5 @@ def __init__(
     ):
         super().__init__(output_folder, output_filename=output_filename, compression=compression, adapter=adapter)
 
-    def _write(self, document: dict, file: IO):
-        file.write(json.dumps(document, ensure_ascii=False) + "\n")
+    def _write(self, document: dict, file_handler: IO, _filename: str):
+        file_handler.write(json.dumps(document, ensure_ascii=False) + "\n")
src/datatrove/pipeline/writers/parquet.py

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+from collections import defaultdict
+from typing import IO, Callable
+
+from datatrove.io import DataFolderLike
+from datatrove.pipeline.writers.disk_base import DiskWriter
+
+
+class ParquetWriter(DiskWriter):
+    default_output_filename: str = "${rank}.parquet"
+    name = "📒 Parquet"
+    _requires_dependencies = ["pyarrow"]
+
+    def __init__(
+        self,
+        output_folder: DataFolderLike,
+        output_filename: str = None,
+        compression: str | None = None,
+        adapter: Callable = None,
+        batch_size: int = 1000,
+    ):
+        super().__init__(output_folder, output_filename, compression, adapter, mode="wb")
+        self._writers = {}
+        self._batches = defaultdict(list)
+        self.batch_size = batch_size
+
+    def _write_batch(self, filename):
+        if not self._batches[filename]:
+            return
+        import pyarrow as pa
+
+        # prepare batch
+        batch = pa.RecordBatch.from_pylist(self._batches.pop(filename))
+        # write batch
+        self._writers[filename].write_batch(batch)
+
+    def _write(self, document: dict, file_handler: IO, filename: str):
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        if filename not in self._writers:
+            self._writers[filename] = pq.ParquetWriter(
+                file_handler, schema=pa.RecordBatch.from_pylist([document]).schema
+            )
+        self._batches[filename].append(document)
+        if len(self._batches[filename]) == self.batch_size:
+            self._write_batch(filename)
+
+    def close(self):
+        for filename in list(self._batches.keys()):
+            self._write_batch(filename)
+        for writer in self._writers.values():
+            writer.close()
+        self._batches.clear()
+        self._writers.clear()
+        super().close()
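For reference, a usage sketch mirroring the new test below. Documents are buffered per output file and flushed as Arrow record batches of `batch_size` rows; the Parquet schema is inferred from the first document written to each file, so documents going to the same file should share the same fields. The output path here is just an example:

from datatrove.data import Document
from datatrove.pipeline.writers.parquet import ParquetWriter

docs = [
    Document(text=f"document {i}", id=str(i), metadata={"source": "example"})
    for i in range(5)
]

# example local output folder; any DataFolderLike accepted by the base writer should work
with ParquetWriter(output_folder="output/parquet", batch_size=2) as writer:
    for doc in docs:
        writer.write(doc)  # buffered, flushed every 2 docs and on close()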
Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+import shutil
+import tempfile
+import unittest
+
+from datatrove.data import Document
+from datatrove.pipeline.readers.parquet import ParquetReader
+from datatrove.pipeline.writers.parquet import ParquetWriter
+
+from ..utils import require_pyarrow
+
+
+@require_pyarrow
+class TestParquetWriter(unittest.TestCase):
+    def setUp(self):
+        # Create a temporary directory
+        self.tmp_dir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, self.tmp_dir)
+
+    def test_write(self):
+        data = [
+            Document(text=text, id=str(i), metadata={"somedata": 2 * i, "somefloat": i * 0.4, "somestring": "hello"})
+            for i, text in enumerate(["hello", "text2", "more text"])
+        ]
+        with ParquetWriter(output_folder=self.tmp_dir, batch_size=2) as w:
+            for doc in data:
+                w.write(doc)
+        reader = ParquetReader(self.tmp_dir)
+        c = 0
+        for read_doc, original in zip(reader(), data):
+            read_doc.metadata.pop("file_path", None)
+            assert read_doc == original
+            c += 1
+        assert c == len(data)
