Skip to content

Commit 93696bb

Browse files
authored
(enhancement): Extend get and update ruleset DQ methods (#1882)
* (enhancement): Extend get and update ruleset DQ methods
1 parent 3ec0145 commit 93696bb

File tree

5 files changed

+579
-90
lines changed

5 files changed

+579
-90
lines changed

awswrangler/data_quality/_create.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from awswrangler import _utils, exceptions
1212
from awswrangler._config import apply_configs
13+
from awswrangler.data_quality._get import get_ruleset
1314
from awswrangler.data_quality._utils import (
1415
_create_datasource,
1516
_get_data_quality_results,
@@ -27,7 +28,7 @@ def _create_dqdl(
2728
"""Create DQDL from pandas data frame."""
2829
rules = []
2930
for rule_type, parameter, expression in df_rules.itertuples(index=False):
30-
parameter_str = f' "{parameter}" ' if parameter else " "
31+
parameter_str = f" {parameter} " if parameter else " "
3132
expression_str = expression if expression else ""
3233
rules.append(f"{rule_type}{parameter_str}{expression_str}")
3334
return "Rules = [ " + ", ".join(rules) + " ]"
@@ -85,7 +86,7 @@ def create_ruleset(
8586
>>> df = pd.DataFrame({"c0": [0, 1, 2], "c1": [0, 1, 2], "c2": [0, 0, 1]})
8687
>>> df_rules = pd.DataFrame({
8788
>>> "rule_type": ["RowCount", "IsComplete", "Uniqueness"],
88-
>>> "parameter": [None, "c0", "c0"],
89+
>>> "parameter": [None, '"c0"', '"c0"'],
8990
>>> "expression": ["between 1 and 6", None, "> 0.95"],
9091
>>> })
9192
>>> wr.s3.to_parquet(df, path, dataset=True, database="database", table="table")
@@ -121,6 +122,7 @@ def create_ruleset(
121122
def update_ruleset(
122123
name: str,
123124
updated_name: Optional[str] = None,
125+
mode: str = "overwrite",
124126
df_rules: Optional[pd.DataFrame] = None,
125127
dqdl_rules: Optional[str] = None,
126128
description: str = "",
@@ -134,6 +136,8 @@ def update_ruleset(
134136
Ruleset name.
135137
updated_name : str
136138
New ruleset name if renaming an existing ruleset.
139+
mode : str
140+
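"overwrite" (default) replaces all rules in the ruleset. "upsert" replaces existing rules that match on `rule_type` and `parameter`, keeps the remaining existing rules, and adds any new ones.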
137141
df_rules : pandas.DataFrame, optional
138142
Data frame with `rule_type`, `parameter`, and `expression` columns.
139143
dqdl_rules : str, optional
@@ -145,25 +149,46 @@ def update_ruleset(
145149
146150
Examples
147151
--------
152+
Overwrite rules in the existing ruleset.
148153
>>> wr.data_quality.update_ruleset(
149154
>>> name="ruleset",
150155
>>> updated_name="my_ruleset",
151156
>>> dqdl_rules="Rules = [ RowCount between 1 and 3 ]",
152157
>>>)
158+
159+
Update or insert rules in the existing ruleset.
160+
>>> wr.data_quality.update_ruleset(
161+
>>> name="ruleset",
162+
>>> mode="upsert",
163+
>>> dqdl_rules="Rules = [ RowCount between 1 and 3 ]",
164+
>>>)
153165
"""
154166
if (df_rules is not None and dqdl_rules) or (df_rules is None and not dqdl_rules):
155167
raise exceptions.InvalidArgumentCombination("You must pass either ruleset `df_rules` or `dqdl_rules`.")
168+
if mode not in ["overwrite", "upsert"]:
169+
raise exceptions.InvalidArgumentValue("`mode` must be one of 'overwrite' or 'upsert'.")
170+
171+
if mode == "upsert":
172+
df_existing = get_ruleset(name=name, boto3_session=boto3_session)
173+
df_existing = df_existing.set_index(keys=["rule_type", "parameter"], drop=False, verify_integrity=True)
174+
df_updated = _rules_to_df(dqdl_rules) if dqdl_rules is not None else df_rules
175+
df_updated = df_updated.set_index(keys=["rule_type", "parameter"], drop=False, verify_integrity=True)
176+
merged_df = pd.concat([df_existing[~df_existing.index.isin(df_updated.index)], df_updated])
177+
dqdl_rules = _create_dqdl(merged_df.reset_index(drop=True))
178+
else:
179+
dqdl_rules = _create_dqdl(df_rules) if df_rules is not None else dqdl_rules
180+
181+
args = {
182+
"Name": name,
183+
"Description": description,
184+
"Ruleset": dqdl_rules,
185+
}
186+
if updated_name:
187+
args["UpdatedName"] = updated_name
156188

157189
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
158-
dqdl_rules = _create_dqdl(df_rules) if df_rules is not None else dqdl_rules
159-
160190
try:
161-
client_glue.update_data_quality_ruleset(
162-
Name=name,
163-
UpdatedName=updated_name,
164-
Description=description,
165-
Ruleset=dqdl_rules,
166-
)
191+
client_glue.update_data_quality_ruleset(**args)
167192
except client_glue.exceptions.EntityNotFoundException as not_found:
168193
raise exceptions.ResourceDoesNotExist(f"Ruleset {name} does not exist.") from not_found
169194

@@ -327,7 +352,7 @@ def evaluate_ruleset(
327352
>>> dqdl_rules="Rules = [ RowCount between 1 and 3 ]",
328353
>>>)
329354
>>> df_ruleset_results = wr.data_quality.evaluate_ruleset(
330-
>>> name=["ruleset1", "rulseset2"],
355+
>>> name="ruleset",
331356
>>> iam_role_arn=glue_data_quality_role,
332357
>>> )
333358
"""

awswrangler/data_quality/_get.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""AWS Glue Data Quality Get Module."""
22

3-
from typing import Optional, cast
3+
from typing import List, Optional, Union, cast
44

55
import boto3
66
import pandas as pd
@@ -9,28 +9,39 @@
99

1010

1111
def get_ruleset(
12-
name: str,
12+
name: Union[str, List[str]],
1313
boto3_session: Optional[boto3.Session] = None,
1414
) -> pd.DataFrame:
1515
"""Get a Data Quality ruleset.
1616
1717
Parameters
1818
----------
19-
name : str
20-
Ruleset name.
19+
name : str or list[str]
20+
Ruleset name or list of names.
2121
boto3_session : boto3.Session, optional
2222
Boto3 Session. If none, the default boto3 session is used.
2323
2424
Returns
2525
-------
2626
pd.DataFrame
27-
Data frame with ruleset details.
27+
Data frame with ruleset(s) details.
2828
2929
Examples
3030
--------
31+
Get single ruleset
3132
>>> import awswrangler as wr
3233
3334
>>> df_ruleset = wr.data_quality.get_ruleset(name="my_ruleset")
35+
36+
Get multiple rulesets. A `ruleset` column with the ruleset name is added to the data frame.
37+
>>> df_rulesets = wr.data_quality.get_ruleset(name=["ruleset_1", "ruleset_2"])
3438
"""
35-
rules = cast(str, _get_ruleset(ruleset_name=name, boto3_session=boto3_session)["Ruleset"])
36-
return _rules_to_df(rules=rules)
39+
ruleset_names: List[str] = name if isinstance(name, list) else [name]
40+
dfs: List[pd.DataFrame] = []
41+
for ruleset_name in ruleset_names:
42+
rules = cast(str, _get_ruleset(ruleset_name=ruleset_name, boto3_session=boto3_session)["Ruleset"])
43+
df = _rules_to_df(rules=rules)
44+
if len(ruleset_names) > 1:
45+
df["ruleset"] = ruleset_name
46+
dfs.append(df)
47+
return pd.concat(dfs)

awswrangler/data_quality/_utils.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Dict, List, Optional, Tuple, Union, cast
99

1010
import boto3
11+
import botocore.exceptions
1112
import pandas as pd
1213

1314
from awswrangler import _utils, exceptions
@@ -23,12 +24,13 @@ def _parse_rules(rules: List[str]) -> List[Tuple[str, Optional[str], Optional[st
2324
for rule in rules:
2425
rule_type, remainder = tuple(rule.split(maxsplit=1))
2526
if remainder.startswith('"'):
26-
remainder_split = remainder.split(maxsplit=1)
27-
parameter = remainder_split[0].strip('"')
28-
expression = None if len(remainder_split) == 1 else remainder_split[1]
27+
expression_regex = r"\s+(?:[=><]|between\s+.+\s+and\s+|in\s+\[.+\]|matches\s+).*"
28+
expression_matches = re.findall(expression_regex, remainder)
29+
expression = None if len(expression_matches) == 0 else expression_matches[0].strip()
30+
parameter = remainder.split(expression)[0].strip() if expression else remainder
2931
else:
30-
parameter = None
3132
expression = remainder
33+
parameter = None
3234
parsed_rules.append((rule_type, parameter, expression))
3335
return parsed_rules
3436

@@ -115,10 +117,18 @@ def _get_ruleset_run(
115117
) -> Dict[str, Any]:
116118
session: boto3.Session = _utils.ensure_session(session=boto3_session)
117119
client_glue: boto3.client = _utils.client(service_name="glue", session=session)
118-
if run_type == "recommendation":
119-
response = client_glue.get_data_quality_rule_recommendation_run(RunId=run_id)
120-
elif run_type == "evaluation":
121-
response = client_glue.get_data_quality_ruleset_evaluation_run(RunId=run_id)
120+
f = (
121+
client_glue.get_data_quality_rule_recommendation_run
122+
if run_type == "recommendation"
123+
else client_glue.get_data_quality_ruleset_evaluation_run
124+
)
125+
response = _utils.try_it(
126+
f=f,
127+
ex=botocore.exceptions.ClientError,
128+
ex_code="ThrottlingException",
129+
max_num_tries=5,
130+
RunId=run_id,
131+
)
122132
return cast(Dict[str, Any], response)
123133

124134

@@ -148,7 +158,14 @@ def _get_ruleset(
148158
) -> Dict[str, Any]:
149159
boto3_session = _utils.ensure_session(session=boto3_session)
150160
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
151-
return cast(Dict[str, Any], client_glue.get_data_quality_ruleset(Name=ruleset_name))
161+
response = _utils.try_it(
162+
f=client_glue.get_data_quality_ruleset,
163+
ex=botocore.exceptions.ClientError,
164+
ex_code="ThrottlingException",
165+
max_num_tries=5,
166+
Name=ruleset_name,
167+
)
168+
return cast(Dict[str, Any], response)
152169

153170

154171
def _get_data_quality_results(

0 commit comments

Comments
 (0)