 
 from __future__ import annotations
 
+import os
 from typing import TYPE_CHECKING
 
 from onetl.base.base_file_df_connection import BaseFileDFConnection
-from onetl.file import FileDFReader, FileDFWriter
+from onetl.file import FileDFReader, FileDFWriter, FileMover
+from onetl.file.filter import Glob
 
 from syncmaster.dto.connections import ConnectionDTO
 from syncmaster.dto.transfers import FileTransferDTO
 if TYPE_CHECKING:
     from pyspark.sql.dataframe import DataFrame
 
+COLUMN_FORMATS = ("parquet", "orc")
+
 
 class FileHandler(Handler):
-    connection: BaseFileDFConnection
+    """
+    TODO: FileHandler is actually a handler for FileDFWriter with a remote FS (direct write).
+    FileProtocolHandler is a handler for FileDFWriter with a local FS (write via upload).
+    Maybe we should keep only common methods here, like the file name generator,
+    and move the others to the classes where they are really used.
+    """
+
+    df_connection: BaseFileDFConnection
     connection_dto: ConnectionDTO
     transfer_dto: FileTransferDTO
     _operators = {
@@ -35,12 +46,29 @@ class FileHandler(Handler):
         "not_ilike": "NOT ILIKE",
         "regexp": "RLIKE",
     }
+    _compression_to_file_suffix = {
+        "gzip": "gz",
+        "snappy": "snappy",
+        "zlib": "zlib",
+        "lz4": "lz4",
+        "bzip2": "bz2",
+        "deflate": "deflate",
+    }
+    _file_format_to_file_suffix = {
+        "json": "json",
+        "jsonline": "jsonl",
+        "csv": "csv",
+        "xml": "xml",
+        "excel": "xlsx",
+        "parquet": "parquet",
+        "orc": "orc",
+    }
 
     def read(self) -> DataFrame:
         from pyspark.sql.types import StructType
 
         reader = FileDFReader(
-            connection=self.connection,
+            connection=self.df_connection,
             format=self.transfer_dto.file_format,
             source_path=self.transfer_dto.directory_path,
             df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
@@ -59,14 +87,65 @@ def read(self) -> DataFrame:
         return df
 
     def write(self, df: DataFrame) -> None:
-        writer = FileDFWriter(
-            connection=self.connection,
-            format=self.transfer_dto.file_format,
-            target_path=self.transfer_dto.directory_path,
-            options=self.transfer_dto.options,
+        tmp_path = os.path.join(self.transfer_dto.directory_path, ".tmp", str(self.run_dto.id))
+        try:
+            writer = FileDFWriter(
+                connection=self.df_connection,
+                format=self.transfer_dto.file_format,
+                target_path=tmp_path,
+                options=self.transfer_dto.options,
+            )
+            writer.run(df=df)
+
+            self._rename_files(tmp_path)
+
+            mover = FileMover(
+                connection=self.file_connection,
+                source_path=tmp_path,
+                target_path=self.transfer_dto.directory_path,
+                # ignore .crc and other metadata files
+                filters=[Glob(f"*.{self._get_file_extension()}")],
+            )
+            mover.run()
+        finally:
+            self.file_connection.remove_dir(tmp_path, recursive=True)
+
+    def _rename_files(self, tmp_path: str) -> None:
+        files = self.file_connection.list_dir(tmp_path)
+
+        for index, file_name in enumerate(files):
+            extension = self._get_file_extension()
+            new_name = self._get_file_name(str(index), extension)
+            old_path = os.path.join(tmp_path, file_name)
+            new_path = os.path.join(tmp_path, new_name)
+            self.file_connection.rename_file(old_path, new_path)
+
+    def _get_file_name(self, index: str, extension: str) -> str:
+        return self.transfer_dto.file_name_template.format(
+            index=index,
+            extension=extension,
+            run_id=self.run_dto.id,
+            run_created_at=self.run_dto.created_at.strftime("%Y_%m_%d_%H_%M_%S"),
         )
 
-        return writer.run(df=df)
+    def _get_file_extension(self) -> str:
+        file_format = self.transfer_dto.file_format.__class__.__name__.lower()
+        extension_suffix = self._file_format_to_file_suffix[file_format]
+
+        compression = getattr(self.transfer_dto.file_format, "compression", "none")
+        if compression == "none":
+            return extension_suffix
+
+        compression_suffix = self._compression_to_file_suffix[compression]
+
+        # https://github.com/apache/parquet-java/blob/fb6f0be0323f5f52715b54b8c6602763d8d0128d/parquet-common/src/main/java/org/apache/parquet/hadoop/metadata/CompressionCodecName.java#L26-L33
+        if extension_suffix == "parquet" and compression_suffix == "lz4":
+            return "lz4hadoop.parquet"
+
+        if extension_suffix in COLUMN_FORMATS:
+            return f"{compression_suffix}.{extension_suffix}"
+
+        return f"{extension_suffix}.{compression_suffix}"
 
     def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
         expressions = []
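
For quick reference, here is a minimal standalone sketch of the suffix rules that the new _get_file_extension() encodes, rewritten as plain functions; the dict and function names below are illustrative and not part of the commit. Columnar formats (parquet, orc) put the compression codec before the format suffix, row-based formats put it after, and Parquet with LZ4 is special-cased because parquet-java names that codec "lz4hadoop".

# Illustrative sketch only: mirrors the suffix rules added in this commit.
COLUMN_FORMATS = ("parquet", "orc")
COMPRESSION_SUFFIX = {"gzip": "gz", "snappy": "snappy", "zlib": "zlib",
                      "lz4": "lz4", "bzip2": "bz2", "deflate": "deflate"}
FORMAT_SUFFIX = {"json": "json", "jsonline": "jsonl", "csv": "csv",
                 "xml": "xml", "excel": "xlsx", "parquet": "parquet", "orc": "orc"}

def file_extension(file_format: str, compression: str = "none") -> str:
    extension = FORMAT_SUFFIX[file_format]
    if compression == "none":
        return extension
    suffix = COMPRESSION_SUFFIX[compression]
    if extension == "parquet" and suffix == "lz4":
        # parquet-java calls its Hadoop-flavoured LZ4 codec "lz4hadoop"
        return "lz4hadoop.parquet"
    if extension in COLUMN_FORMATS:
        return f"{suffix}.{extension}"    # columnar: codec first, e.g. "snappy.parquet"
    return f"{extension}.{suffix}"        # row-based: codec last, e.g. "csv.gz"

print(file_extension("csv", "gzip"))        # csv.gz
print(file_extension("parquet", "snappy"))  # snappy.parquet
print(file_extension("parquet", "lz4"))     # lz4hadoop.parquet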