Commit 913790a

add s3 to mysql
1 parent 46583d5 commit 913790a

2 files changed: 241 additions & 1 deletion

__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,11 +1,12 @@
 from airflow.plugins_manager import AirflowPlugin
 from mysql_plugin.hooks.astro_mysql_hook import AstroMySqlHook
 from mysql_plugin.operators.mysql_to_s3_operator import MySQLToS3Operator
+from mysql_plugin.operators.s3_to_mysql_operator import S3ToMySQLOperator


 class MySQLToS3Plugin(AirflowPlugin):
     name = "MySQLToS3Plugin"
-    operators = [MySQLToS3Operator]
+    operators = [MySQLToS3Operator, S3ToMySQLOperator]
     # Leave in for explicitness
     hooks = [AstroMySqlHook]
     executors = []
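For reference, a minimal sketch of how the newly registered operator might be imported in a DAG file. The direct package import mirrors the line added above; the airflow.operators path is an assumption about legacy Airflow 1.x plugin loading, where operators registered by a plugin become importable under the plugin's name.

# Direct import from the plugin package (mirrors the import added above).
from mysql_plugin.operators.s3_to_mysql_operator import S3ToMySQLOperator

# Assumption: under legacy Airflow 1.x plugin loading, the operator should
# also be importable via the plugin's registered name:
# from airflow.operators.MySQLToS3Plugin import S3ToMySQLOperator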

operators/s3_to_mysql_operator.py

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
from airflow.models import BaseOperator
from airflow.hooks.S3_hook import S3Hook
from airflow.hooks.mysql_hook import MySqlHook
import dateutil.parser
import json
import logging


class S3ToMySQLOperator(BaseOperator):
    """
    S3 To MySQL Operator

    NOTE: To avoid invalid characters, it is recommended
    to specify the character encoding (e.g. {"charset": "utf8"}).

    :param s3_conn_id: The source s3 connection id.
    :type s3_conn_id: string
    :param s3_bucket: The source s3 bucket.
    :type s3_bucket: string
    :param s3_key: The source s3 key.
    :type s3_key: string
    :param mysql_conn_id: The destination mysql connection id.
    :type mysql_conn_id: string
    :param database: The destination database name.
    :type database: string
    :param table: The destination mysql table name.
    :type table: string
    :param field_schema: An array of dicts in the following format:
                         {'name': 'column_name', 'type': 'int(11)'}
                         which determine what fields will be created
                         and inserted.
    :type field_schema: array
    :param primary_key: The primary key for the destination table.
                        Multiple strings in the array signify a
                        compound key.
    :type primary_key: array
    :param incremental_key: *(optional)* The incremental key to compare
                            new data against the destination table with.
                            Only required if using a load_type of "upsert".
    :type incremental_key: string
    :param load_type: The method of loading into MySQL that should
                      occur. Options are "append", "rebuild", and
                      "upsert". Defaults to "append".
    :type load_type: string
    """

    template_fields = ('s3_key',)

    def __init__(self,
                 s3_conn_id,
                 s3_bucket,
                 s3_key,
                 mysql_conn_id,
                 database,
                 table,
                 field_schema,
                 primary_key=None,
                 incremental_key=None,
                 load_type='append',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

        self.mysql_conn_id = mysql_conn_id
        self.s3_conn_id = s3_conn_id
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.table = table
        self.database = database
        self.field_schema = field_schema
        # Use None rather than a mutable list as the default argument and
        # fall back to an empty compound key.
        self.primary_key = primary_key or []
        self.incremental_key = incremental_key
        self.load_type = load_type

    def execute(self, context):
        m_hook = MySqlHook(self.mysql_conn_id)

        data = (S3Hook(self.s3_conn_id)
                .get_key(self.s3_key, bucket_name=self.s3_bucket)
                .get_contents_as_string(encoding='utf-8'))

        self.copy_data(m_hook, data)

    def copy_data(self, m_hook, data):
        if self.load_type == 'rebuild':
            drop_query = \
                """
                DROP TABLE IF EXISTS {schema}.{table}
                """.format(schema=self.database, table=self.table)
            m_hook.run(drop_query)

        table_exists_query = \
            """
            SELECT *
            FROM information_schema.tables
            WHERE table_schema = '{database}' AND table_name = '{table}'
            """.format(database=self.database, table=self.table)

        if not m_hook.get_records(table_exists_query):
            self.create_table(m_hook)
        else:
            self.reconcile_schemas(m_hook)

        self.write_data(m_hook, data)

    def create_table(self, m_hook):
        # Fields are surrounded by `` in order to avoid namespace conflicts
        # with reserved words in MySQL.
        # https://dev.mysql.com/doc/refman/5.7/en/identifiers.html

        fields = ['`{name}` {type} {nullable}'.format(
                      name=field['name'],
                      type=field['type'],
                      nullable='NOT NULL'
                               if field['name'] in self.primary_key
                               else 'NULL')
                  for field in self.field_schema]

        # Quote each key column individually so that a compound key renders
        # as `a`, `b` rather than `a, b`.
        keys = ', '.join('`{0}`'.format(key) for key in self.primary_key)

        create_query = \
            """
            CREATE TABLE IF NOT EXISTS {schema}.{table} ({fields}
            """.format(schema=self.database,
                       table=self.table,
                       fields=', '.join(fields))

        if keys:
            create_query += ', PRIMARY KEY ({keys})'.format(keys=keys)

        create_query += ')'

        m_hook.run(create_query)

    def reconcile_schemas(self, m_hook):
        describe_query = 'DESCRIBE {schema}.{table}'.format(
            schema=self.database, table=self.table)
        records = m_hook.get_records(describe_query)
        existing_column_names = [record[0] for record in records]
        incoming_column_names = [field['name'] for field in self.field_schema]
        missing_columns = list(set(incoming_column_names) -
                               set(existing_column_names))
        if missing_columns:
            columns = ['ADD COLUMN {name} {type} NULL'.format(
                           name=field['name'], type=field['type'])
                       for field in self.field_schema
                       if field['name'] in missing_columns]

            alter_query = \
                """
                ALTER TABLE {schema}.{table} {columns}
                """.format(schema=self.database,
                           table=self.table,
                           columns=', '.join(columns))

            m_hook.run(alter_query)
            logging.info('The new columns were: ' + str(missing_columns))
        else:
            logging.info('There were no new columns.')

    def write_data(self, m_hook, data):
        fields = ', '.join([field['name'] for field in self.field_schema])

        placeholders = ', '.join('%({name})s'.format(name=field['name'])
                                 for field in self.field_schema)

        insert_query = \
            """
            INSERT INTO {schema}.{table} ({columns})
            VALUES ({placeholders})
            """.format(schema=self.database,
                       table=self.table,
                       columns=fields,
                       placeholders=placeholders)

        if self.load_type == 'upsert':
            # Add an IF check to ensure that the records being inserted
            # have an incremental_key with a value greater than the
            # existing records.
            update_set = ', '.join(["""
                {name} = IF({ik} < VALUES({ik}),
                            VALUES({name}), {name})
                """.format(name=field['name'],
                           ik=self.incremental_key)
                for field in self.field_schema])

            insert_query += ('ON DUPLICATE KEY UPDATE {update_set}'
                             .format(update_set=update_set))

        # Split the incoming newline-delimited JSON string on '\n' and
        # drop the empty entries produced by runs of two or more '\n'.
        records = [record for record in data.split('\n') if record]

        # Create a default "record" object with all available fields
        # initialized to None. These will be overwritten with the proper
        # field values as available.
        default_object = {}

        for field in self.field_schema:
            default_object[field['name']] = None

        output = []

        for record in records:
            # json.loads maps JSON null to None natively, so missing and
            # null fields both end up as None in the line object.
            line_object = default_object.copy()
            line_object.update(json.loads(record))
            output.append(line_object)

        date_fields = [field['name'] for field in self.field_schema
                       if field['type'] in ['datetime', 'date']]

        def convert_timestamps(key, value):
            if key in date_fields:
                try:
                    # Parse strings that look like timestamps and convert
                    # them to datetime. Set ignoretz=False to keep timezones
                    # embedded in the datetime.
                    # http://bit.ly/2zwcebe
                    return dateutil.parser.parse(value, ignoretz=False)
                except (ValueError, TypeError, OverflowError):
                    # If the value does not match a timestamp or is null,
                    # return the initial value.
                    return value
            else:
                return value

        output = [dict([k, convert_timestamps(k, v)] if v is not None
                       else [k, v]
                       for k, v in i.items()) for i in output]

        conn = m_hook.get_conn()
        cur = conn.cursor()
        cur.executemany(insert_query, output)
        cur.close()
        conn.commit()
        conn.close()
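
For context, a minimal DAG sketch showing how this operator might be wired up. All connection ids, bucket, key, and schema values below are hypothetical placeholders, not values from this commit; the S3 object is assumed to be newline-delimited JSON, one object per line, as write_data expects.

from datetime import datetime

from airflow import DAG
from mysql_plugin.operators.s3_to_mysql_operator import S3ToMySQLOperator

dag = DAG('s3_to_mysql_example',              # hypothetical DAG id
          start_date=datetime(2017, 1, 1),
          schedule_interval='@daily')

# Each line of the S3 object is one JSON record, e.g.:
# {"id": 1, "email": "a@example.com", "updated_at": "2017-01-01T00:00:00Z"}
load_users = S3ToMySQLOperator(
    task_id='s3_to_mysql',
    s3_conn_id='my_s3_conn',                  # hypothetical connection id
    s3_bucket='my-data-bucket',               # hypothetical bucket
    s3_key='users/{{ ds }}.json',             # templated via template_fields
    mysql_conn_id='my_mysql_conn',            # hypothetical connection id
    database='analytics',
    table='users',
    field_schema=[{'name': 'id', 'type': 'int(11)'},
                  {'name': 'email', 'type': 'varchar(255)'},
                  {'name': 'updated_at', 'type': 'datetime'}],
    primary_key=['id'],
    incremental_key='updated_at',
    load_type='upsert',
    dag=dag)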
