
Commit a3dc7bd

Update GX DQ script and inputs to include more tables and data loads
1 parent a80fbbd commit a3dc7bd

File tree: 2 files changed (+126, -88 lines)

scripts/helpers/housing_nec_migration_gx_dq_inputs.py (as imported by the script below)

Lines changed: 13 additions & 18 deletions
@@ -1,22 +1,17 @@
-sql_config = {
-    "properties_1a": {
-        "sql": """ SELECT *
-                   FROM "housing_nec_migration"."properties_1a" """,
-        "id_field": "LPRO_PROPREF",
-    },
-    "properties_1b": {
-        "sql": """ SELECT *
-                   FROM "housing_nec_migration"."properties_1b" """,
-        "id_field": "LPRO_PROPREF",
-    },
-    "properties_1c": {
-        "sql": """ SELECT *
-                   FROM "housing_nec_migration"."properties_1c" """,
-        "id_field": "LPRO_PROPREF",
-    },
-}
+sql_config = {"properties": {"id_field": "LPRO_PROPREF"}}

+data_load_list = ["properties"]

-table_list = ["properties_1a", "properties_1b", "properties_1c"]
+table_list = {
+    "properties": [
+        "properties_1a",
+        "properties_1b",
+        "properties_1c",
+        "properties_1d",
+        "properties_1e",
+        "properties_4a",
+        "properties_4c",
+    ]
+}

 partition_keys = ["import_date"]
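
The script in this commit consumes these reshaped inputs as a nested loop: one pass per entry in data_load_list, then one pass per table listed for that load in table_list. A minimal sketch of that consumption, assuming the helper module is importable exactly as the script imports it:

    from scripts.helpers.housing_nec_migration_gx_dq_inputs import (
        sql_config,
        data_load_list,
        table_list,
    )

    for data_load in data_load_list:                  # e.g. "properties"
        id_field = sql_config[data_load]["id_field"]  # "LPRO_PROPREF"
        for table in table_list[data_load]:           # "properties_1a" ... "properties_4c"
            # the per-table query the script builds via its get_sql_query helper
            query = f"SELECT * FROM housing_nec_migration.{table}"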

scripts/jobs/housing/housing_nec_migration_apply_gx_dq_tests.py

Lines changed: 113 additions & 70 deletions
@@ -10,18 +10,30 @@
 import great_expectations as gx
 import pandas as pd
 from pyathena import connect
-from scripts.helpers.housing_nec_migration_gx_dq_inputs import sql_config, table_list
+from scripts.helpers.housing_nec_migration_gx_dq_inputs import sql_config, data_load_list, table_list
 import scripts.jobs.housing.housing_nec_migration_properties_data_load_gx_suite

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-arg_keys = ['region_name', 's3_endpoint', 's3_target_location', 's3_staging_location', 'target_database',
-            'target_table']
+arg_keys = [
+    "region_name",
+    "s3_endpoint",
+    "s3_target_location",
+    "s3_staging_location",
+    "target_database",
+    "target_table",
+]
 args = getResolvedOptions(sys.argv, arg_keys)
 locals().update(args)


+def get_sql_query(sql_config, data_load, table):
+    query = f"SELECT * FROM housing_nec_migration.{table}"
+    id_field = sql_config.get(data_load).get("id_field")
+    return query, id_field
+
+
 def json_serial(obj):
     """JSON serializer for objects not serializable by default."""
     if isinstance(obj, (datetime, date)):
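
For reference, the new get_sql_query helper builds the Athena query from the table name and looks the ID field up per data load rather than per table. A quick illustration of its return values, using get_sql_query as defined above and the sql_config entry from the updated inputs file:

    # Illustration only; the sql_config value is copied from the inputs file in this commit.
    query, id_field = get_sql_query(
        sql_config={"properties": {"id_field": "LPRO_PROPREF"}},
        data_load="properties",
        table="properties_1d",
    )
    # query    == "SELECT * FROM housing_nec_migration.properties_1d"
    # id_field == "LPRO_PROPREF"
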
@@ -35,88 +47,117 @@ def main():

     table_results_df_list = []

-    for table in table_list:
-        logger.info(f'{table} loading...')
-
-        sql_query = sql_config.get(table).get('sql')
+    for data_load in data_load_list:
+        logger.info(f'{data_load} loading...')

-        conn = connect(s3_staging_dir=s3_staging_location,
-                       region_name=region_name)
+        for table in table_list.get(data_load):
+            logger.info(f"{table} loading...")

-        df = pd.read_sql_query(sql_query, conn)
+            sql_query, id_field = get_sql_query(sql_config=sql_config, data_load=data_load, table=table)

-        # set up batch
-        data_source = context.data_sources.add_pandas(f'{table}_pandas')
-        data_asset = data_source.add_dataframe_asset(name=f'{table}_df_asset')
-        batch_definition = data_asset.add_batch_definition_whole_dataframe("Athena batch definition")
-        batch_parameters = {"dataframe": df}
+            conn = connect(s3_staging_dir=s3_staging_location, region_name=region_name)

-        # get expectation suite for dataset
-        suite = context.suites.get(name='properties_data_load_suite')
+            df = pd.read_sql_query(sql_query, conn)

-        validation_definition = gx.ValidationDefinition(
-            data=batch_definition,
-            suite=suite,
-            name=f'validation_definition_{table}')
-        validation_definition = context.validation_definitions.add(validation_definition)
-
-        # create and start checking data with checkpoints
-        checkpoint = context.checkpoints.add(
-            gx.checkpoint.checkpoint.Checkpoint(
-                name=f'{table}_checkpoint',
-                validation_definitions=[validation_definition],
-                result_format={"result_format": "COMPLETE",
-                               "return_unexpected_index_query": False,
-                               "partial_unexpected_count": 0}
+            # set up batch
+            data_source = context.data_sources.add_pandas(f"{table}_pandas")
+            data_asset = data_source.add_dataframe_asset(name=f"{table}_df_asset")
+            batch_definition = data_asset.add_batch_definition_whole_dataframe(
+                "Athena batch definition"
             )
-        )
+            batch_parameters = {"dataframe": df}

-        checkpoint_result = checkpoint.run(batch_parameters=batch_parameters)
-        results_dict = list(checkpoint_result.run_results.values())[0].to_json_dict()
-        table_results_df = pd.json_normalize(results_dict['results'])
-        cols_not_needed = ['result.unexpected_list', 'result.observed_value']
-        cols_to_drop = [c for c in table_results_df.columns if c.startswith('exception_info') or c in cols_not_needed]
+            # get expectation suite for dataset
+            suite = context.suites.get(name=f"{data_load}_data_load_suite")

-        table_results_df = table_results_df.drop(columns=cols_to_drop)
-        table_results_df_list.append(table_results_df)
+            validation_definition = gx.ValidationDefinition(
+                data=batch_definition, suite=suite, name=f"validation_definition_{table}"
+            )
+            validation_definition = context.validation_definitions.add(
+                validation_definition
+            )

-        # generate id lists for each unexpected result set
-        query_df = table_results_df.loc[(~table_results_df['result.unexpected_index_list'].isna()) & (
-            table_results_df['result.unexpected_index_list'].values != '[]')]
+            # create and start checking data with checkpoints
+            checkpoint = context.checkpoints.add(
+                gx.checkpoint.checkpoint.Checkpoint(
+                    name=f"{table}_checkpoint",
+                    validation_definitions=[validation_definition],
+                    result_format={
+                        "result_format": "COMPLETE",
+                        "return_unexpected_index_query": False,
+                        "partial_unexpected_count": 0,
+                    },
+                )
+            )

-        table_results_df['unexpected_id_list'] = pd.Series(dtype='object')
-        for i, row in query_df.iterrows():
-            table_results_df.loc[i, 'unexpected_id_list'] = str(
-                list(df[sql_config.get(table).get('id_field')].iloc[row['result.unexpected_index_list']]))
+            checkpoint_result = checkpoint.run(batch_parameters=batch_parameters)
+            results_dict = list(checkpoint_result.run_results.values())[0].to_json_dict()
+            table_results_df = pd.json_normalize(results_dict["results"])
+            cols_not_needed = ["result.unexpected_list", "result.observed_value"]
+            cols_to_drop = [
+                c
+                for c in table_results_df.columns
+                if c.startswith("exception_info") or c in cols_not_needed
+            ]
+
+            table_results_df = table_results_df.drop(columns=cols_to_drop)
+            table_results_df_list.append(table_results_df)
+
+            # generate id lists for each unexpected result set
+            query_df = table_results_df.loc[
+                (~table_results_df["result.unexpected_index_list"].isna())
+                & (table_results_df["result.unexpected_index_list"].values != "[]")
+            ]
+
+            table_results_df["unexpected_id_list"] = pd.Series(dtype="object")
+            for i, row in query_df.iterrows():
+                table_results_df.loc[i, "unexpected_id_list"] = str(
+                    list(
+                        df[id_field].iloc[
+                            row["result.unexpected_index_list"]
+                        ]
+                    )
+                )

     results_df = pd.concat(table_results_df_list)

     # add clean dataset name
-    results_df['dataset_name'] = results_df['expectation_config.kwargs.batch_id'].map(
-        lambda x: x.removeprefix('pandas-').removesuffix('_df_asset'))
+    results_df["dataset_name"] = results_df["expectation_config.kwargs.batch_id"].map(
+        lambda x: x.removeprefix("pandas-").removesuffix("_df_asset")
+    )

     # add composite key for each specific test (so can be tracked over time)
-    results_df.insert(loc=0, column='expectation_key',
-                      value=results_df.set_index(['expectation_config.type', 'dataset_name']).index.factorize()[0] + 1)
-    results_df['expectation_id'] = results_df['expectation_config.type'] + "_" + results_df['dataset_name']
-    results_df['import_date'] = datetime.today().strftime('%Y%m%d')
+    results_df.insert(
+        loc=0,
+        column="expectation_key",
+        value=results_df.set_index(
+            ["expectation_config.type", "dataset_name"]
+        ).index.factorize()[0]
+        + 1,
+    )
+    results_df["expectation_id"] = (
+        results_df["expectation_config.type"] + "_" + results_df["dataset_name"]
+    )
+    results_df["import_date"] = datetime.today().strftime("%Y%m%d")

     # set dtypes for Athena
-    dtype_dict = {'expectation_config.type': 'string',
-                  'expectation_config.kwargs.batch_id': 'string',
-                  'expectation_config.kwargs.column': 'string',
-                  'expectation_config.kwargs.min_value': 'string',
-                  'expectation_config.kwargs.max_value': 'string',
-                  'result.element_count': 'bigint',
-                  'result.unexpected_count': 'bigint',
-                  'result.missing_count': 'bigint',
-                  'result.partial_unexpected_list': 'array<string>',
-                  'result.unexpected_index_list': 'array<bigint>',
-                  'result.unexpected_index_query': 'string',
-                  'expectation_config.kwargs.regex': 'string',
-                  'expectation_config.kwargs.value_set': 'string',
-                  'expectation_config.kwargs.column_list': 'string',
-                  'import_date': 'string'}
+    dtype_dict = {
+        "expectation_config.type": "string",
+        "expectation_config.kwargs.batch_id": "string",
+        "expectation_config.kwargs.column": "string",
+        "expectation_config.kwargs.min_value": "string",
+        "expectation_config.kwargs.max_value": "string",
+        "result.element_count": "bigint",
+        "result.unexpected_count": "bigint",
+        "result.missing_count": "bigint",
+        "result.partial_unexpected_list": "array<string>",
+        "result.unexpected_index_list": "array<bigint>",
+        "result.unexpected_index_query": "string",
+        "expectation_config.kwargs.regex": "string",
+        "expectation_config.kwargs.value_set": "string",
+        "expectation_config.kwargs.column_list": "string",
+        "import_date": "string",
+    }

     # write to s3
     wr.s3.to_parquet(
@@ -127,11 +168,13 @@ def main():
         table=target_table,
         mode="overwrite",
         dtype=dtype_dict,
-        schema_evolution=True
+        schema_evolution=True,
     )

-    logger.info(f'Data Quality test results for NEC data loads written to {s3_target_location}')
+    logger.info(
+        f"Data Quality test results for NEC data loads written to {s3_target_location}"
+    )


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
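
On the results side, the composite expectation_key is still produced by factorising the (expectation type, dataset name) pairs, so rows that test the same expectation on the same dataset share one integer key. A small standalone illustration of that mechanic with hypothetical rows (not taken from a real run):

    import pandas as pd

    # Hypothetical rows: factorize() assigns the same code to repeated
    # (expectation type, dataset_name) pairs, mirroring the script's expectation_key.
    df = pd.DataFrame(
        {
            "expectation_config.type": [
                "expect_column_values_to_not_be_null",
                "expect_column_values_to_not_be_null",
                "expect_column_values_to_match_regex",
                "expect_column_values_to_not_be_null",
            ],
            "dataset_name": ["properties_1a", "properties_1b", "properties_1a", "properties_1a"],
        }
    )
    df.insert(
        loc=0,
        column="expectation_key",
        value=df.set_index(["expectation_config.type", "dataset_name"]).index.factorize()[0] + 1,
    )
    print(df["expectation_key"].tolist())  # [1, 2, 3, 1]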
