11import dataclasses
22from abc import ABC , abstractmethod
33from string import Template
4- from typing import Callable
4+ from typing import IO , Callable
55
66from datatrove .data import Document , DocumentsPipeline
77from datatrove .io import DataFolderLike , get_datafolder
@@ -31,6 +31,7 @@ def __init__(
3131 output_filename : str = None ,
3232 compression : str | None = "infer" ,
3333 adapter : Callable = None ,
34+ mode : str = "wt" ,
3435 ):
3536 """
3637 Base writer block to save data to disk.
@@ -47,7 +48,7 @@ def __init__(
4748 if self .compression == "gzip" and not output_filename .endswith (".gz" ):
4849 output_filename += ".gz"
4950 self .output_filename = Template (output_filename )
50- self .output_mg = self .output_folder .get_output_file_manager (mode = "wt" , compression = compression )
51+ self .output_mg = self .output_folder .get_output_file_manager (mode = mode , compression = compression )
5152 self .adapter = adapter if adapter else _default_adapter
5253
5354 def __enter__ (self ):
@@ -81,13 +82,13 @@ def _get_output_filename(self, document: Document, rank: int | str = 0, **kwargs
8182 )
8283
8384 @abstractmethod
84- def _write (self , document : dict , file_handler ):
85+ def _write (self , document : dict , file_handler : IO , filename : str ):
8586 """
8687 Main method that subclasses should implement. Receives an adapted (after applying self.adapter) dictionary with data to save to `file_handler`
8788 Args:
8889 document: dictionary with the data to save
8990 file_handler: file_handler where it should be saved
90-
91+ filename: to use as a key for writer helpers and other data
9192 Returns:
9293
9394 """
@@ -105,7 +106,7 @@ def write(self, document: Document, rank: int = 0, **kwargs):
105106
106107 """
107108 output_filename = self ._get_output_filename (document , rank , ** kwargs )
108- self ._write (self .adapter (document ), self .output_mg .get_file (output_filename ))
109+ self ._write (self .adapter (document ), self .output_mg .get_file (output_filename ), output_filename )
109110 self .stat_update (self ._get_output_filename (document , "XXXXX" , ** kwargs ))
110111 self .stat_update (StatHints .total )
111112 self .update_doc_stats (document )
0 commit comments