@@ -3,10 +3,11 @@
 
 from __future__ import annotations
 
+import os
 from typing import TYPE_CHECKING
 
 from onetl.base.base_file_df_connection import BaseFileDFConnection
-from onetl.file import FileDFReader, FileDFWriter
+from onetl.file import FileDFReader, FileDFWriter, FileMover
 
 from syncmaster.dto.connections import ConnectionDTO
 from syncmaster.dto.transfers import FileTransferDTO
@@ -17,7 +18,7 @@
 
 
 class FileHandler(Handler):
-    connection: BaseFileDFConnection
+    df_connection: BaseFileDFConnection
     connection_dto: ConnectionDTO
     transfer_dto: FileTransferDTO
     _operators = {
@@ -40,7 +41,7 @@ def read(self) -> DataFrame:
         from pyspark.sql.types import StructType
 
         reader = FileDFReader(
-            connection=self.connection,
+            connection=self.df_connection,
             format=self.transfer_dto.file_format,
             source_path=self.transfer_dto.directory_path,
             df_schema=StructType.fromJson(self.transfer_dto.df_schema) if self.transfer_dto.df_schema else None,
@@ -59,14 +60,74 @@ def read(self) -> DataFrame:
         return df
 
     def write(self, df: DataFrame) -> None:
+        tmp_path = os.path.join(self.transfer_dto.directory_path, ".tmp", str(self.run_dto.id))
         writer = FileDFWriter(
-            connection=self.connection,
+            connection=self.df_connection,
             format=self.transfer_dto.file_format,
-            target_path=self.transfer_dto.directory_path,
+            target_path=tmp_path,
             options=self.transfer_dto.options,
         )
+        writer.run(df=df)
+
+        self._rename_files(tmp_path)
+
+        mover = FileMover(
+            connection=self.connection,
+            source_path=tmp_path,
+            target_path=self.transfer_dto.directory_path,
+        )
+        mover.run()
+
+    def _rename_files(self, tmp_path: str) -> None:
+        files = self.connection.list_dir(tmp_path)
+
+        for index, file_name in enumerate(files):
+            extension = self._get_file_extension(str(file_name))
+            new_name = self._get_file_name(str(index), extension)
+            old_path = os.path.join(tmp_path, file_name)
+            new_path = os.path.join(tmp_path, new_name)
+            self.connection.rename_file(old_path, new_path)
+
+    def _get_file_name(self, index: str, extension: str) -> str:
+        return (
+            self.transfer_dto.file_name_template.replace(
+                "{index}",
+                index,
+            )
+            .replace(
+                "{extension}",
+                extension,
+            )
+            .replace(
+                "{run_id}",
+                str(self.run_dto.id),
+            )
+            .replace(
+                "{run_created_at}",
+                self.run_dto.created_at.strftime("%Y_%m_%d_%H_%M_%S"),
+            )
+        )
+
+    def _get_file_extension(self, file_name: str) -> str:
+        extension = self.transfer_dto.file_format.name
+        parts = file_name.split(".")
+
+        if extension == "xml":  # spark-xml does not write any extension to files
+            if len(parts) <= 1:
+                return extension
+
+            compression = parts[-1]
+
+        else:
+            if len(parts) <= 2:
+                return extension
+
+            compression = parts[-1] if parts[-1] != extension else parts[-2]
+
+        if extension in ("parquet", "orc"):
+            return f"{compression}.{extension}"
 
-        return writer.run(df=df)
+        return f"{extension}.{compression}"
 
     def _make_rows_filter_expression(self, filters: list[dict]) -> str | None:
         expressions = []
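
For reference, a standalone sketch of what the renaming helpers in this diff produce, assuming a hypothetical file_name_template of "{run_id}_{run_created_at}_{index}.{extension}" and two made-up Spark part-file names; the template value, run metadata, and sample names are illustrative assumptions, not values taken from this commit.

# Standalone sketch (not part of the commit): illustrates the renaming helpers.
# The template, run metadata, and part-file names below are hypothetical.
from datetime import datetime


def get_file_extension(file_name: str, fmt: str) -> str:
    # Mirrors FileHandler._get_file_extension: keep the compression suffix
    # Spark appends to its part files (e.g. ".gz", ".snappy").
    parts = file_name.split(".")
    if fmt == "xml":  # spark-xml writes no extension of its own
        if len(parts) <= 1:
            return fmt
        compression = parts[-1]
    else:
        if len(parts) <= 2:
            return fmt
        compression = parts[-1] if parts[-1] != fmt else parts[-2]
    if fmt in ("parquet", "orc"):
        return f"{compression}.{fmt}"
    return f"{fmt}.{compression}"


# Hypothetical template and run metadata, for illustration only.
template = "{run_id}_{run_created_at}_{index}.{extension}"
run_id = 42
run_created_at = datetime(2024, 1, 1, 12, 0, 0)

samples = [("part-0000-abc.csv.gz", "csv"), ("part-0001-abc.snappy.parquet", "parquet")]
for index, (part_file, fmt) in enumerate(samples):
    new_name = (
        template.replace("{index}", str(index))
        .replace("{extension}", get_file_extension(part_file, fmt))
        .replace("{run_id}", str(run_id))
        .replace("{run_created_at}", run_created_at.strftime("%Y_%m_%d_%H_%M_%S"))
    )
    print(f"{part_file} -> {new_name}")
    # part-0000-abc.csv.gz         -> 42_2024_01_01_12_00_00_0.csv.gz
    # part-0001-abc.snappy.parquet -> 42_2024_01_01_12_00_00_1.snappy.parquet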