diff --git a/operators/s3_to_redshift_operator.py b/operators/s3_to_redshift_operator.py index 194260b..a9ca31f 100644 --- a/operators/s3_to_redshift_operator.py +++ b/operators/s3_to_redshift_operator.py @@ -58,7 +58,8 @@ class S3ToRedshiftOperator(BaseOperator): :param incremental_key: *(optional)* The incremental key to compare new data against the destination table with. Only required if using a load_type of - "upsert". + "upsert". This may be either + a list or a string. :type incremental_key: string :param foreign_key: *(optional)* This specifies any foreign_keys in the table and which corresponding table @@ -239,20 +240,42 @@ def getS3Conn(): # and the primary key is the same. # (e.g. Source: {"id": 1, "updated_at": "2017-01-02 00:00:00"}; # Destination: {"id": 1, "updated_at": "2017-01-01 00:00:00"}) - - delete_sql = \ - ''' + if isinstance(self.primary_key, list): + where_pk = "" + for i, item in enumerate(self.primary_key): + where_pk += """ + "{rs_schema}"."{rs_table}"."{item}" = "{rs_schema}"."{rs_table}{rs_suffix}"."{item}" + """.format( + rs_schema=self.redshift_schema, + rs_table=self.table, + rs_suffix=self.temp_suffix, + item=item, + ) + if i != (len(self.primary_key) - 1): + where_pk += " AND " + else: + where_pk = '"{rs_schema}"."{rs_table}"."{rs_pk}" = "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_pk}"'.format( + rs_schema=self.redshift_schema, + rs_table=self.table, + rs_pk=self.primary_key, + rs_suffix=self.temp_suffix, + rs_ik=self.incremental_key, + ) + + delete_sql = """ DELETE FROM "{rs_schema}"."{rs_table}" USING "{rs_schema}"."{rs_table}{rs_suffix}" - WHERE "{rs_schema}"."{rs_table}"."{rs_pk}" = - "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_pk}" + WHERE {where_pk} AND "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_ik}" >= "{rs_schema}"."{rs_table}"."{rs_ik}" - '''.format(rs_schema=self.redshift_schema, - rs_table=self.table, - rs_pk=self.primary_key, - rs_suffix=self.temp_suffix, - rs_ik=self.incremental_key) + """.format( + rs_schema=self.redshift_schema, + rs_table=self.table, + rs_pk=self.primary_key, + rs_suffix=self.temp_suffix, + rs_ik=self.incremental_key, + where_pk=where_pk, + ) # Delete records from the source table where the incremental_key # is greater than or equal to the incremental_key of the destination @@ -264,19 +287,20 @@ def getS3Conn(): # (e.g. Source: {"id": 1, "updated_at": "2017-01-01 00:00:00"}; # Destination: {"id": 1, "updated_at": "2017-01-02 00:00:00"}) - delete_confirm_sql = \ - ''' + delete_confirm_sql = """ DELETE FROM "{rs_schema}"."{rs_table}{rs_suffix}" USING "{rs_schema}"."{rs_table}" - WHERE "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_pk}" = - "{rs_schema}"."{rs_table}"."{rs_pk}" + WHERE {where_pk} AND "{rs_schema}"."{rs_table}"."{rs_ik}" >= "{rs_schema}"."{rs_table}{rs_suffix}"."{rs_ik}" - '''.format(rs_schema=self.redshift_schema, - rs_table=self.table, - rs_pk=self.primary_key, - rs_suffix=self.temp_suffix, - rs_ik=self.incremental_key) + """.format( + rs_schema=self.redshift_schema, + rs_table=self.table, + rs_pk=self.primary_key, + rs_suffix=self.temp_suffix, + rs_ik=self.incremental_key, + where_pk=where_pk, + ) append_sql = \ ''' @@ -371,7 +395,10 @@ def create_if_not_exists(self, schema, pg_hook, temp=False): sk = '' if self.primary_key: - pk = ', primary key("{0}")'.format(self.primary_key) + if isinstance(self.primary_key, str): + pk = ", primary key({0})".format(self.primary_key) + elif isinstance(self.primary_key, list): + pk = ", primary key({0})".format(", ".join(self.primary_key)) if self.foreign_key: if isinstance(self.foreign_key, list):