1010
1111from awswrangler import _utils , exceptions
1212from awswrangler ._config import apply_configs
13+ from awswrangler .data_quality ._get import get_ruleset
1314from awswrangler .data_quality ._utils import (
1415 _create_datasource ,
1516 _get_data_quality_results ,
@@ -27,7 +28,7 @@ def _create_dqdl(
2728 """Create DQDL from pandas data frame."""
2829 rules = []
2930 for rule_type , parameter , expression in df_rules .itertuples (index = False ):
30- parameter_str = f' " { parameter } " ' if parameter else " "
31+ parameter_str = f" { parameter } " if parameter else " "
3132 expression_str = expression if expression else ""
3233 rules .append (f"{ rule_type } { parameter_str } { expression_str } " )
3334 return "Rules = [ " + ", " .join (rules ) + " ]"
@@ -85,7 +86,7 @@ def create_ruleset(
8586 >>> df = pd.DataFrame({"c0": [0, 1, 2], "c1": [0, 1, 2], "c2": [0, 0, 1]})
8687 >>> df_rules = pd.DataFrame({
8788 >>> "rule_type": ["RowCount", "IsComplete", "Uniqueness"],
88- >>> "parameter": [None, "c0", "c0"],
89+ >>> "parameter": [None, ' "c0"', ' "c0"' ],
8990 >>> "expression": ["between 1 and 6", None, "> 0.95"],
9091 >>> })
9192 >>> wr.s3.to_parquet(df, path, dataset=True, database="database", table="table")
@@ -121,6 +122,7 @@ def create_ruleset(
121122def update_ruleset (
122123 name : str ,
123124 updated_name : Optional [str ] = None ,
125+ mode : str = "overwrite" ,
124126 df_rules : Optional [pd .DataFrame ] = None ,
125127 dqdl_rules : Optional [str ] = None ,
126128 description : str = "" ,
@@ -134,6 +136,8 @@ def update_ruleset(
134136 Ruleset name.
135137 updated_name : str
136138 New ruleset name if renaming an existing ruleset.
139+ mode : str
140+ overwrite (default) or upsert.
137141 df_rules : str, optional
138142 Data frame with `rule_type`, `parameter`, and `expression` columns.
139143 dqdl_rules : str, optional
@@ -145,25 +149,46 @@ def update_ruleset(
145149
146150 Examples
147151 --------
152+ Overwrite rules in the existing ruleset.
148153 >>> wr.data_quality.update_ruleset(
149154 >>> name="ruleset",
150155 >>> new_name="my_ruleset",
151156 >>> dqdl_rules="Rules = [ RowCount between 1 and 3 ]",
152157 >>>)
158+
159+ Update or insert rules in the existing ruleset.
160+ >>> wr.data_quality.update_ruleset(
161+ >>> name="ruleset",
162+ >>> mode="insert",
163+ >>> dqdl_rules="Rules = [ RowCount between 1 and 3 ]",
164+ >>>)
153165 """
154166 if (df_rules is not None and dqdl_rules ) or (df_rules is None and not dqdl_rules ):
155167 raise exceptions .InvalidArgumentCombination ("You must pass either ruleset `df_rules` or `dqdl_rules`." )
168+ if mode not in ["overwrite" , "upsert" ]:
169+ raise exceptions .InvalidArgumentValue ("`mode` must be one of 'overwrite' or 'upsert'." )
170+
171+ if mode == "upsert" :
172+ df_existing = get_ruleset (name = name , boto3_session = boto3_session )
173+ df_existing = df_existing .set_index (keys = ["rule_type" , "parameter" ], drop = False , verify_integrity = True )
174+ df_updated = _rules_to_df (dqdl_rules ) if dqdl_rules is not None else df_rules
175+ df_updated = df_updated .set_index (keys = ["rule_type" , "parameter" ], drop = False , verify_integrity = True )
176+ merged_df = pd .concat ([df_existing [~ df_existing .index .isin (df_updated .index )], df_updated ])
177+ dqdl_rules = _create_dqdl (merged_df .reset_index (drop = True ))
178+ else :
179+ dqdl_rules = _create_dqdl (df_rules ) if df_rules is not None else dqdl_rules
180+
181+ args = {
182+ "Name" : name ,
183+ "Description" : description ,
184+ "Ruleset" : dqdl_rules ,
185+ }
186+ if updated_name :
187+ args ["UpdatedName" ] = updated_name
156188
157189 client_glue : boto3 .client = _utils .client (service_name = "glue" , session = boto3_session )
158- dqdl_rules = _create_dqdl (df_rules ) if df_rules is not None else dqdl_rules
159-
160190 try :
161- client_glue .update_data_quality_ruleset (
162- Name = name ,
163- UpdatedName = updated_name ,
164- Description = description ,
165- Ruleset = dqdl_rules ,
166- )
191+ client_glue .update_data_quality_ruleset (** args )
167192 except client_glue .exceptions .EntityNotFoundException as not_found :
168193 raise exceptions .ResourceDoesNotExist (f"Ruleset { name } does not exist." ) from not_found
169194
@@ -327,7 +352,7 @@ def evaluate_ruleset(
327352 >>> dqdl_rules="Rules = [ RowCount between 1 and 3 ]",
328353 >>>)
329354 >>> df_ruleset_results = wr.data_quality.evaluate_ruleset(
330- >>> name=["ruleset1", "rulseset2"] ,
355+ >>> name="ruleset" ,
331356 >>> iam_role_arn=glue_data_quality_role,
332357 >>> )
333358 """
0 commit comments