@@ -3,10 +3,12 @@

from __future__ import annotations

+import os
+import shutil
from typing import TYPE_CHECKING

from onetl.base.base_file_df_connection import BaseFileDFConnection
-from onetl.file import FileDFReader, FileDFWriter
+from onetl.file import FileDFReader, FileDFWriter, FileMover

from syncmaster.dto.connections import ConnectionDTO
from syncmaster.dto.transfers import FileTransferDTO
@@ -17,7 +19,7 @@


class FileHandler(Handler):
-    connection: BaseFileDFConnection
+    df_connection: BaseFileDFConnection
    connection_dto: ConnectionDTO
    transfer_dto: FileTransferDTO
    _operators = {
@@ -40,7 +42,7 @@ def read(self) -> DataFrame:
        from pyspark.sql.types import StructType

        reader = FileDFReader(
-            connection=self.connection,
+            connection=self.df_connection,
            format=self.transfer_dto.file_format,
            source_path=self.transfer_dto.directory_path,
            df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
@@ -59,14 +61,65 @@ def read(self) -> DataFrame:
        return df

    def write(self, df: DataFrame) -> None:
-        writer = FileDFWriter(
-            connection=self.connection,
-            format=self.transfer_dto.file_format,
-            target_path=self.transfer_dto.directory_path,
-            options=self.transfer_dto.options,
+        tmp_path = os.path.join(self.transfer_dto.directory_path, ".tmp", str(self.run_dto.id))
+        try:
+            writer = FileDFWriter(
+                connection=self.df_connection,
+                format=self.transfer_dto.file_format,
+                target_path=tmp_path,
+                options=self.transfer_dto.options,
+            )
+            writer.run(df=df)
+
+            self._rename_files(tmp_path)
+
+            mover = FileMover(
+                connection=self.file_connection,
+                source_path=tmp_path,
+                target_path=self.transfer_dto.directory_path,
+            )
+            mover.run()
+        finally:
+            shutil.rmtree(tmp_path, ignore_errors=True)
+
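A minimal sketch of the layout this write path produces, under assumed values (directory_path "/data/target", run id 123, plain CSV output, and a hypothetical file name template; none of these are taken from the diff):

    # /data/target/.tmp/123/part-00000-<uuid>.csv     <- staged by FileDFWriter
    # /data/target/.tmp/123/<name from template>.csv  <- after _rename_files()
    # /data/target/<name from template>.csv           <- after FileMover.run(); the .tmp run directory is removed in the finally block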
+    def _rename_files(self, tmp_path: str) -> None:
+        files = self.file_connection.list_dir(tmp_path)
+
+        for index, file_name in enumerate(files):
+            extension = self._get_file_extension(str(file_name))
+            new_name = self._get_file_name(str(index), extension)
+            old_path = os.path.join(tmp_path, file_name)
+            new_path = os.path.join(tmp_path, new_name)
+            self.file_connection.rename_file(old_path, new_path)
+
+    def _get_file_name(self, index: str, extension: str) -> str:
+        return self.transfer_dto.file_name_template.format(
+            index=index,
+            extension=extension,
+            run_id=self.run_dto.id,
+            run_created_at=self.run_dto.created_at.strftime("%Y_%m_%d_%H_%M_%S"),
        )

-        return writer.run(df=df)
+    def _get_file_extension(self, file_name: str) -> str:
+        extension = self.transfer_dto.file_format.name
+        parts = file_name.split(".")
+
+        if extension == "xml":  # spark-xml does not write any extension to files
+            if len(parts) <= 1:
+                return extension
+
+            compression = parts[-1]
+
+        else:
+            if len(parts) <= 2:
+                return extension
+
+            compression = parts[-1] if parts[-1] != extension else parts[-2]
+
+        if extension in ("parquet", "orc"):
+            return f"{compression}.{extension}"
+
+        return f"{extension}.{compression}"

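A quick sketch of how the two helpers above compose, assuming a hypothetical file_name_template of "{run_id}_{index}.{extension}" and run id 123 (both illustrative, not taken from the diff):

    # _get_file_extension("part-00000-abc.csv.gz")            -> "csv.gz"
    # _get_file_name("0", "csv.gz")                           -> "123_0.csv.gz"
    # _get_file_extension("part-00000-abc.snappy.parquet")    -> "snappy.parquet"  (parquet/orc keep Spark's <compression>.<extension> order)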
    def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
        expressions = []