@@ -3,10 +3,12 @@
 
 from __future__ import annotations
 
+import os
 from typing import TYPE_CHECKING
 
 from onetl.base.base_file_df_connection import BaseFileDFConnection
-from onetl.file import FileDFReader, FileDFWriter
+from onetl.file import FileDFReader, FileDFWriter, FileMover
+from onetl.file.filter import Glob
 
 from syncmaster.dto.connections import ConnectionDTO
 from syncmaster.dto.transfers import FileTransferDTO
@@ -17,7 +19,7 @@
 
 
 class FileHandler(Handler):
-    connection: BaseFileDFConnection
+    df_connection: BaseFileDFConnection
     connection_dto: ConnectionDTO
     transfer_dto: FileTransferDTO
     _operators = {
@@ -35,12 +37,20 @@ class FileHandler(Handler):
         "not_ilike": "NOT ILIKE",
         "regexp": "RLIKE",
     }
+    _compressions = {
+        "gzip": "gz",
+        "snappy": "snappy",
+        "zlib": "zlib",
+        "lz4": "lz4",
+        "bzip2": "bz2",
+        "deflate": "deflate",
+    }
 
     def read(self) -> DataFrame:
         from pyspark.sql.types import StructType
 
         reader = FileDFReader(
-            connection=self.connection,
+            connection=self.df_connection,
             format=self.transfer_dto.file_format,
             source_path=self.transfer_dto.directory_path,
             df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
@@ -59,14 +69,61 @@ def read(self) -> DataFrame:
         return df
 
     def write(self, df: DataFrame) -> None:
-        writer = FileDFWriter(
-            connection=self.connection,
-            format=self.transfer_dto.file_format,
-            target_path=self.transfer_dto.directory_path,
-            options=self.transfer_dto.options,
+        tmp_path = os.path.join(self.transfer_dto.directory_path, ".tmp", str(self.run_dto.id))
+        try:
+            writer = FileDFWriter(
+                connection=self.df_connection,
+                format=self.transfer_dto.file_format,
+                target_path=tmp_path,
+                options=self.transfer_dto.options,
+            )
+            writer.run(df=df)
+
+            self._rename_files(tmp_path)
+
+            mover = FileMover(
+                connection=self.file_connection,
+                source_path=tmp_path,
+                target_path=self.transfer_dto.directory_path,
+                filters=[Glob(f"*.{self._get_file_extension()}")],
+            )
+            mover.run()
+        finally:
+            self.file_connection.remove_dir(tmp_path, recursive=True)
+
+    def _rename_files(self, tmp_path: str) -> None:
+        files = self.file_connection.list_dir(tmp_path)
+
+        for index, file_name in enumerate(files):
+            extension = self._get_file_extension()
+            new_name = self._get_file_name(str(index), extension)
+            old_path = os.path.join(tmp_path, file_name)
+            new_path = os.path.join(tmp_path, new_name)
+            self.file_connection.rename_file(old_path, new_path)
+
+    def _get_file_name(self, index: str, extension: str) -> str:
+        return self.transfer_dto.file_name_template.format(
+            index=index,
+            extension=extension,
+            run_id=self.run_dto.id,
+            run_created_at=self.run_dto.created_at.strftime("%Y_%m_%d_%H_%M_%S"),
         )
 
-        return writer.run(df=df)
+    def _get_file_extension(self) -> str:
+        extension = self.transfer_dto.file_format.name
+        compression = getattr(self.transfer_dto.file_format, "compression", None)
+        if not compression or compression == "none":
+            return extension
+
+        compression = self._compressions[compression]
+
+        if extension == "parquet" and compression == "lz4":
+            return "lz4hadoop.parquet"
+
+        if extension in ("parquet", "orc"):
+            return f"{compression}.{extension}"
+
+        return f"{extension}.{compression}"
 
     def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
         expressions = []