from random import choice
from string import ascii_lowercase

# Hook import paths assume the Airflow 1.x module layout; adjust them
# (and the S3Hook credentials call below) for your Airflow version.
from airflow.models import BaseOperator
from airflow.hooks.postgres_hook import PostgresHook
from airflow.hooks.S3_hook import S3Hook

class S3ToRedshiftOperator(BaseOperator):
    """
    S3 -> Redshift via COPY command.

    With load_type='append', the rendered copy_cmd runs directly
    against the destination table. With load_type='upsert', the data
    is first copied into a temp clone of the destination table,
    destination rows that are matched (on join_key) and older (on
    incremental_key) are deleted, and the staged rows are inserted.

    copy_cmd is a template expected to expose {creds}, {bucket},
    {key}, and {table} placeholders; {table} lets the upsert path
    point the COPY at the staging table.
    """

    template_fields = ('s3_key', 'copy_cmd')

    def __init__(self,
                 s3_conn_id, s3_bucket, s3_key,
                 rs_conn_id, rs_schema, rs_table,
                 copy_cmd, load_type='append',
                 join_key=None, incremental_key=None,
                 *args, **kwargs):

        super().__init__(*args, **kwargs)
        self.s3_conn_id = s3_conn_id
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key

        self.rs_conn_id = rs_conn_id
        self.rs_schema = rs_schema
        self.rs_table = rs_table

        self.copy_cmd = copy_cmd
        self.load_type = load_type
        self.join_key = join_key
        self.incremental_key = incremental_key

        # Set by duplicate_tbl_schema() when an upsert stages data
        self.tmp_tbl = None
        self.tmp_schema = None

        if self.load_type not in ['append', 'upsert']:
            raise Exception('Please choose "append" or "upsert".')

        if self.load_type == 'upsert' and (self.join_key is None or self.incremental_key is None):
            raise Exception('Upserts require join_key and incremental_key to be specified.')

    def drop_tbl_ddl(self, schema, tbl, if_exists=True):
        """Build a DROP TABLE statement, optionally guarded by IF EXISTS."""
        base_drop = "DROP TABLE {if_exists} {schema}.{tbl}"

        if_exists = 'IF EXISTS' if if_exists else ''

        return base_drop.format(
            if_exists=if_exists,
            schema=schema,
            tbl=tbl
        )

    def duplicate_tbl_schema(self, old_schema, old_tbl, new_tbl=None, new_schema=None):
        """
        Build a CREATE TABLE (LIKE ...) statement that clones the
        destination table's structure into an empty staging table,
        recording the staging table's name and schema on the instance.
        """
        new_tbl = new_tbl if new_tbl is not None else old_tbl
        new_schema = new_schema if new_schema is not None else old_schema

        cmd = 'CREATE TABLE {new_schema}.{new_tbl} (LIKE {old_schema}.{old_tbl});'

        # Give new_tbl a unique name in case more than one task is running
        rand4 = ''.join(choice(ascii_lowercase) for _ in range(4))
        if new_tbl == old_tbl:
            new_tbl += '_tmp_' + rand4

        self.tmp_tbl = new_tbl
        self.tmp_schema = new_schema

        return cmd.format(
            new_schema=new_schema,
            new_tbl=new_tbl,
            old_schema=old_schema,
            old_tbl=old_tbl
        )

    def del_from_tbl_ddl(self, del_schema, del_tbl, join_schema, join_tbl, conditions=None):
        """
        Build a DELETE ... USING statement; the joined table is
        aliased join_tbl so callers can reference it in conditions.
        """
        delete = """DELETE FROM {src_schema}.{src_tbl} USING {join_schema}.{join_tbl} join_tbl"""

        delete = delete.format(
            src_schema=del_schema,
            src_tbl=del_tbl,
            join_schema=join_schema,
            join_tbl=join_tbl
        )

        if conditions:
            delete += '\n WHERE '
            delete += '\n AND '.join(conditions)

        return delete

    def insert_stg_into_dst_ddl(self, dst_schema, dst_tbl, stg_schema, stg_tbl):
        """Build an INSERT that moves all staged rows into the destination."""
        insert = """INSERT INTO {dst_schema}.{dst_tbl}\n (SELECT * FROM {stg_schema}.{stg_tbl});"""

        return insert.format(
            dst_schema=dst_schema,
            dst_tbl=dst_tbl,
            stg_schema=stg_schema,
            stg_tbl=stg_tbl
        )

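    # Taken together, an upsert run issues roughly the following
    # statements (illustrative schema/table/column names; the _tmp_
    # suffix is four random letters):
    #
    #   CREATE TABLE analytics.events_tmp_abcd (LIKE analytics.events);
    #   COPY analytics.events_tmp_abcd FROM 's3://...' CREDENTIALS '...';
    #   DELETE FROM analytics.events USING analytics.events_tmp_abcd join_tbl
    #    WHERE events.event_id = join_tbl.event_id
    #    AND events.updated_at < join_tbl.updated_at;
    #   INSERT INTO analytics.events
    #    (SELECT * FROM analytics.events_tmp_abcd);
    #   DROP TABLE IF EXISTS analytics.events_tmp_abcd;
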
    def execute(self, context):
        """
        Runs the COPY command on Redshift, staging through a temp
        table first when load_type is 'upsert'.
        """
        pg = PostgresHook(postgres_conn_id=self.rs_conn_id)

        a_key, s_key = S3Hook(s3_conn_id=self.s3_conn_id).get_credentials()
        conn_str = 'aws_access_key_id={};aws_secret_access_key={}'.format(a_key, s_key)

        # If append -> normal copy straight into the destination table
        if self.load_type == 'append':
            copy_cmd = self.copy_cmd.format(
                creds=conn_str,
                bucket=self.s3_bucket,
                key=self.s3_key,
                table='{}.{}'.format(self.rs_schema, self.rs_table)
            )
            pg.run(copy_cmd)

        else:
            # Duplicate dst tbl into an empty staging table
            duplicate_tbl = self.duplicate_tbl_schema(self.rs_schema, self.rs_table)
            pg.run(duplicate_tbl)

            # COPY into the staging table rather than the destination;
            # this relies on the {table} placeholder in copy_cmd, since
            # the staging table's name is randomized per run
            copy_cmd = self.copy_cmd.format(
                creds=conn_str,
                bucket=self.s3_bucket,
                key=self.s3_key,
                table='{}.{}'.format(self.tmp_schema, self.tmp_tbl)
            )
            pg.run(copy_cmd)

            # DELETE duplicate rows from the destination
            del_conditions = [
                '{}.{} = join_tbl.{}'.format(
                    self.rs_table,
                    self.join_key,
                    self.join_key
                ),
                '{}.{} < join_tbl.{}'.format(
                    self.rs_table,
                    self.incremental_key,
                    self.incremental_key
                )
            ]

            del_ddl = self.del_from_tbl_ddl(
                self.rs_schema,
                self.rs_table,
                self.tmp_schema,
                self.tmp_tbl,
                del_conditions,
            )
            pg.run(del_ddl)

            # Do inserts
            insert_ddl = self.insert_stg_into_dst_ddl(
                self.rs_schema,
                self.rs_table,
                self.tmp_schema,
                self.tmp_tbl,
            )
            pg.run(insert_ddl)

            # Cleanup temp table
            drop_ddl = self.drop_tbl_ddl(self.tmp_schema, self.tmp_tbl)
            pg.run(drop_ddl)
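

# ---------------------------------------------------------------------
# Usage sketch: a minimal example, not part of the operator itself, and
# ordinarily placed in a DAG file rather than alongside the operator.
# All names below (connection IDs, bucket, schema, table, columns) are
# hypothetical, and the COPY options are illustrative. {creds},
# {bucket}, {key}, and {table} are filled in by the operator at run
# time, while {{ ds }} is rendered by Jinja since s3_key is a template
# field.

from datetime import datetime

from airflow import DAG

example_dag = DAG(
    dag_id='s3_to_redshift_example',
    start_date=datetime(2017, 1, 1),
    schedule_interval='@daily',
)

COPY_CMD = """
COPY {table}
FROM 's3://{bucket}/{key}'
CREDENTIALS '{creds}'
CSV GZIP
"""

load_events = S3ToRedshiftOperator(
    task_id='load_events',
    s3_conn_id='s3_default',
    s3_bucket='my-data-bucket',
    s3_key='events/{{ ds }}.csv.gz',
    rs_conn_id='redshift_default',
    rs_schema='analytics',
    rs_table='events',
    copy_cmd=COPY_CMD,
    load_type='upsert',
    join_key='event_id',
    incremental_key='updated_at',
    dag=example_dag,
)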