 
 from __future__ import annotations
 
+import os
 from typing import TYPE_CHECKING
 
 from onetl.base.base_file_df_connection import BaseFileDFConnection
-from onetl.file import FileDFReader, FileDFWriter
+from onetl.file import FileDFReader, FileDFWriter, FileMover
 
 from syncmaster.dto.connections import ConnectionDTO
 from syncmaster.dto.transfers import FileTransferDTO
@@ -17,7 +18,7 @@
 
 
 class FileHandler(Handler):
-    connection: BaseFileDFConnection
+    df_connection: BaseFileDFConnection
     connection_dto: ConnectionDTO
     transfer_dto: FileTransferDTO
     _operators = {
@@ -40,7 +41,7 @@ def read(self) -> DataFrame:
         from pyspark.sql.types import StructType
 
         reader = FileDFReader(
-            connection=self.connection,
+            connection=self.df_connection,
             format=self.transfer_dto.file_format,
             source_path=self.transfer_dto.directory_path,
             df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
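The df_schema branch above relies on PySpark's StructType.fromJson, which rebuilds a schema from a plain dict. A minimal sketch of the dict shape it accepts, with made-up field names that are not taken from the project:

from pyspark.sql.types import StructType

# Hypothetical schema dict, as transfer_dto.df_schema is assumed to store it
schema_json = {
    "type": "struct",
    "fields": [
        {"name": "id", "type": "long", "nullable": False, "metadata": {}},
        {"name": "name", "type": "string", "nullable": True, "metadata": {}},
    ],
}

schema = StructType.fromJson(schema_json)
print(schema.simpleString())  # struct<id:bigint,name:string>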
@@ -59,14 +60,65 @@ def read(self) -> DataFrame:
         return df
 
     def write(self, df: DataFrame) -> None:
-        writer = FileDFWriter(
-            connection=self.connection,
-            format=self.transfer_dto.file_format,
-            target_path=self.transfer_dto.directory_path,
-            options=self.transfer_dto.options,
+        tmp_path = os.path.join(self.transfer_dto.directory_path, ".tmp", str(self.run_dto.id))
+        try:
+            writer = FileDFWriter(
+                connection=self.df_connection,
+                format=self.transfer_dto.file_format,
+                target_path=tmp_path,
+                options=self.transfer_dto.options,
+            )
+            writer.run(df=df)
+
+            self._rename_files(tmp_path)
+
+            mover = FileMover(
+                connection=self.file_connection,
+                source_path=tmp_path,
+                target_path=self.transfer_dto.directory_path,
+            )
+            mover.run()
+        finally:
+            self.file_connection.remove_dir(tmp_path, recursive=True)
+
+    def _rename_files(self, tmp_path: str) -> None:
+        files = self.file_connection.list_dir(tmp_path)
+
+        for index, file_name in enumerate(files):
+            extension = self._get_file_extension(str(file_name))
+            new_name = self._get_file_name(str(index), extension)
+            old_path = os.path.join(tmp_path, file_name)
+            new_path = os.path.join(tmp_path, new_name)
+            self.file_connection.rename_file(old_path, new_path)
+
+    def _get_file_name(self, index: str, extension: str) -> str:
+        return self.transfer_dto.file_name_template.format(
+            index=index,
+            extension=extension,
+            run_id=self.run_dto.id,
+            run_created_at=self.run_dto.created_at.strftime("%Y_%m_%d_%H_%M_%S"),
         )
 
-        return writer.run(df=df)
+    def _get_file_extension(self, file_name: str) -> str:
+        extension = self.transfer_dto.file_format.name
+        parts = file_name.split(".")
+
+        if extension == "xml":  # spark-xml does not write any extension to files
+            if len(parts) <= 1:
+                return extension
+
+            compression = parts[-1]
+
+        else:
+            if len(parts) <= 2:
+                return extension
+
+            compression = parts[-1] if parts[-1] != extension else parts[-2]
+
+        if extension in ("parquet", "orc"):
+            return f"{compression}.{extension}"
+
+        return f"{extension}.{compression}"
 
     def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
         expressions = []
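With this change, write() stages the output in a per-run .tmp sub-directory, renames the Spark part-files according to file_name_template, moves them into the target directory with FileMover, and removes the temporary directory in a finally block. A standalone sketch of the extension rules, with illustrative part-file names and a hypothetical template string, neither of which is taken from the project:

# Standalone re-implementation of the rules in _get_file_extension, for illustration only.
def get_file_extension(file_name: str, format_name: str) -> str:
    parts = file_name.split(".")

    if format_name == "xml":  # spark-xml writes no extension at all
        if len(parts) <= 1:
            return format_name
        compression = parts[-1]
    else:
        if len(parts) <= 2:
            return format_name
        compression = parts[-1] if parts[-1] != format_name else parts[-2]

    # Parquet/ORC put the codec first ("snappy.parquet"), other formats last ("csv.gz")
    if format_name in ("parquet", "orc"):
        return f"{compression}.{format_name}"
    return f"{format_name}.{compression}"


print(get_file_extension("part-00000-abc.csv", "csv"))                 # csv
print(get_file_extension("part-00000-abc.csv.gz", "csv"))              # csv.gz
print(get_file_extension("part-00000-abc.snappy.parquet", "parquet"))  # snappy.parquet
print(get_file_extension("part-00000-abc", "xml"))                     # xml

# Hypothetical file_name_template; unused keyword arguments to str.format are ignored
template = "{run_created_at}_{index}.{extension}"
print(template.format(index=0, extension="csv.gz", run_id=42,
                      run_created_at="2025_01_01_00_00_00"))
# 2025_01_01_00_00_00_0.csv.gz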