Skip to content

Commit 3f86d29

Browse files
committed
add custom mask functionalities
1 parent d51f73a commit 3f86d29

File tree

2 files changed

+186
-16
lines changed

2 files changed

+186
-16
lines changed

aws_lambda_powertools/utilities/data_masking/base.py

Lines changed: 119 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import ast
34
import functools
45
import logging
56
import warnings
@@ -94,15 +95,52 @@ def erase(self, data: tuple, fields: list[str]) -> tuple[str]: ...
9495
@overload
9596
def erase(self, data: dict, fields: list[str]) -> dict: ...
9697

97-
def erase(self, data: Sequence | Mapping, fields: list[str] | None = None) -> str | list[str] | tuple[str] | dict:
98-
return self._apply_action(data=data, fields=fields, action=self.provider.erase)
98+
@overload
def erase(
    self,
    data: dict,
    fields: list[str],
    custom_mask: bool | None = None,
    mask_pattern: str | None = None,
    regex_pattern: str | None = None,
    mask_format: str | None = None,
    masking_rules: dict | None = None,
) -> dict: ...

def erase(
    self,
    data: Sequence | Mapping,
    fields: list[str] | None = None,
    custom_mask: bool | None = None,
    mask_pattern: str | None = None,
    regex_pattern: str | None = None,
    mask_format: str | None = None,
    masking_rules: dict | None = None,
) -> str | list[str] | tuple[str] | dict:
    """
    Erase (mask) data, optionally using custom masking options.

    Parameters
    ----------
    data
        The value to erase. Falsy input (empty string, empty container, None)
        is returned unchanged without calling the provider.
    fields
        Fields to erase; forwarded to `_apply_action`.
    custom_mask, mask_pattern, regex_pattern, mask_format
        Custom masking options, forwarded unchanged to the provider's `erase`
        through `_apply_action`.
    masking_rules
        Mapping of field path to masking options. When truthy it takes
        precedence: the data is handed to `_apply_masking_rules` and `fields`
        plus the other masking options are not used.

    Returns
    -------
    The input itself when it is falsy, otherwise whatever
    `_apply_masking_rules` or `_apply_action` returns.
    """
    # Nothing to mask: hand the (empty) input straight back
    if not data:
        return data

    # Per-field rules win over the single-rule options below
    if masking_rules:
        return self._apply_masking_rules(data, masking_rules)

    return self._apply_action(
        data=data,
        fields=fields,
        action=self.provider.erase,
        custom_mask=custom_mask,
        mask_pattern=mask_pattern,
        regex_pattern=regex_pattern,
        mask_format=mask_format,
    )
99133

100134
def _apply_action(
101135
self,
102136
data,
103137
fields: list[str] | None,
104138
action: Callable,
105139
provider_options: dict | None = None,
140+
custom_mask: bool | None = None,
141+
mask_pattern: str | None = None,
142+
regex_pattern: str | None = None,
143+
mask_format: str | None = None,
106144
**encryption_context: str,
107145
):
108146
"""
@@ -136,18 +174,34 @@ def _apply_action(
136174
fields=fields,
137175
action=action,
138176
provider_options=provider_options,
177+
custom_mask=custom_mask,
178+
mask_pattern=mask_pattern,
179+
regex_pattern=regex_pattern,
180+
mask_format=mask_format,
139181
**encryption_context,
140182
)
141183
else:
142184
logger.debug(f"Running action {action.__name__} with the entire data")
143-
return action(data=data, provider_options=provider_options, **encryption_context)
185+
return action(
186+
data=data,
187+
provider_options=provider_options,
188+
custom_mask=custom_mask,
189+
mask_pattern=mask_pattern,
190+
regex_pattern=regex_pattern,
191+
mask_format=mask_format,
192+
**encryption_context,
193+
)
144194

145195
def _apply_action_to_fields(
146196
self,
147197
data: dict | str,
148198
fields: list,
149199
action: Callable,
150200
provider_options: dict | None = None,
201+
custom_mask: bool | None = None,
202+
mask_pattern: str | None = None,
203+
regex_pattern: str | None = None,
204+
mask_format: str | None = None,
151205
**encryption_context: str,
152206
) -> dict | str:
153207
"""
@@ -194,6 +248,8 @@ def _apply_action_to_fields(
194248
new_dict = {'a': {'b': {'c': '*****'}}, 'x': {'y': '*****'}}
195249
```
196250
"""
251+
if not fields:
252+
raise ValueError("Fields parameter cannot be empty")
197253

198254
data_parsed: dict = self._normalize_data_to_parse(fields, data)
199255

@@ -204,6 +260,10 @@ def _apply_action_to_fields(
204260
self._call_action,
205261
action=action,
206262
provider_options=provider_options,
263+
custom_mask=custom_mask,
264+
mask_pattern=mask_pattern,
265+
regex_pattern=regex_pattern,
266+
mask_format=mask_format,
207267
**encryption_context, # type: ignore[arg-type]
208268
)
209269

@@ -225,12 +285,6 @@ def _apply_action_to_fields(
225285
# For in-place updates, json_parse accepts a callback function
226286
# that receives 3 args: field_value, fields, field_name
227287
# We create a partial callback to pre-populate known provider options (action, provider opts, enc ctx)
228-
update_callback = functools.partial(
229-
self._call_action,
230-
action=action,
231-
provider_options=provider_options,
232-
**encryption_context, # type: ignore[arg-type]
233-
)
234288

235289
json_parse.update(
236290
data_parsed,
@@ -239,13 +293,60 @@ def _apply_action_to_fields(
239293

240294
return data_parsed
241295

296+
def _apply_masking_rules(self, data: dict, masking_rules: dict) -> dict:
297+
"""
298+
Apply masking rules to data, supporting different rules for each field.
299+
"""
300+
result = data.copy()
301+
302+
for path, rule in masking_rules.items():
303+
try:
304+
# Handle nested paths (e.g., 'address.street')
305+
parts = path.split(".")
306+
current = result
307+
308+
for part in parts[:-1]:
309+
if isinstance(current[part], str) and current[part].startswith("{"):
310+
try:
311+
current[part] = ast.literal_eval(current[part])
312+
except (ValueError, SyntaxError):
313+
continue
314+
current = current[part]
315+
316+
final_field = parts[-1]
317+
318+
# Apply masking rule to the target field
319+
if final_field in current:
320+
current[final_field] = self.provider.erase(str(current[final_field]), **rule)
321+
322+
except (KeyError, TypeError, AttributeError):
323+
# Log warning if field not found or invalid path
324+
warnings.warn(f"Could not apply masking rule for path: {path}", stacklevel=2)
325+
continue
326+
327+
return result
328+
329+
def _mask_nested_field(self, data: dict, field_path: str, mask_function):
330+
keys = field_path.split(".")
331+
current = data
332+
for key in keys[:-1]:
333+
current = current.get(key, {})
334+
if not isinstance(current, dict):
335+
return # Caminho inválido
336+
if keys[-1] in current:
337+
current[keys[-1]] = mask_function(current[keys[-1]])
338+
242339
@staticmethod
243340
def _call_action(
244341
field_value: Any,
245342
fields: dict[str, Any],
246343
field_name: str,
247344
action: Callable,
248345
provider_options: dict[str, Any] | None = None,
346+
custom_mask: bool | None = None,
347+
mask_pattern: str | None = None,
348+
regex_pattern: str | None = None,
349+
mask_format: str | None = None,
249350
**encryption_context,
250351
) -> None:
251352
"""
@@ -263,7 +364,15 @@ def _call_action(
263364
Returns:
264365
- fields[field_name]: Returns the processed field value
265366
"""
266-
fields[field_name] = action(field_value, provider_options=provider_options, **encryption_context)
367+
fields[field_name] = action(
368+
field_value,
369+
provider_options=provider_options,
370+
custom_mask=custom_mask,
371+
mask_pattern=mask_pattern,
372+
regex_pattern=regex_pattern,
373+
mask_format=mask_format,
374+
**encryption_context,
375+
)
267376
return fields[field_name]
268377

269378
def _normalize_data_to_parse(self, fields: list, data: str | dict) -> dict:

aws_lambda_powertools/utilities/data_masking/provider/base.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22

33
import functools
44
import json
5+
import re
56
from typing import Any, Callable, Iterable
67

78
from aws_lambda_powertools.utilities.data_masking.constants import DATA_MASKING_STRING
89

10+
# Characters that `_custom_erase` copies through unchanged instead of replacing with "*"
PRESERVE_CHARS = set("-_. ")
11+
# Module-level cache keyed by regex pattern string, holding the compiled pattern object
_regex_cache = {}
12+
913

1014
class BaseProvider:
1115
"""
@@ -63,7 +67,16 @@ def decrypt(self, data, provider_options: dict | None = None, **encryption_conte
6367
"""
6468
raise NotImplementedError("Subclasses must implement decrypt()")
6569

66-
def erase(self, data, **kwargs) -> Iterable[str]:
70+
def erase(
    self,
    data,
    custom_mask: bool | None = None,
    mask_pattern: str | None = None,
    regex_pattern: str | None = None,
    mask_format: str | None = None,
    masking_rules: dict | None = None,
    **kwargs,
) -> Iterable[str]:
    """
    This method irreversibly erases data.

    If the data to be erased is of an iterable type like `list`, `tuple`,
    or `set`, this method returns a new object of the same type as the
    input, with each element erased by a recursive call using the same
    options. An empty container yields an empty container of the same type.

    For other data:

    - a non-empty `str` with `custom_mask` truthy is masked with
      `mask_pattern` if given, otherwise with `regex_pattern` + `mask_format`
      if both are given, otherwise with `_custom_erase`;
    - a non-empty `dict` with `masking_rules` truthy is processed by
      `_apply_masking_rules`;
    - anything else (including falsy values) becomes `DATA_MASKING_STRING`.
    """
    # Containers are handled before the truthiness check so an empty
    # list/tuple/set keeps its type instead of collapsing to the mask string
    if isinstance(data, (list, tuple, set)):
        return type(data)(
            self.erase(
                item,
                custom_mask=custom_mask,
                mask_pattern=mask_pattern,
                regex_pattern=regex_pattern,
                mask_format=mask_format,
                masking_rules=masking_rules,
                **kwargs,
            )
            for item in data
        )

    if not data:
        return DATA_MASKING_STRING

    if isinstance(data, str) and custom_mask:
        if mask_pattern:
            return self._pattern_mask(data, mask_pattern)
        if regex_pattern and mask_format:
            return self._regex_mask(data, regex_pattern, mask_format)
        return self._custom_erase(data, **kwargs)

    if isinstance(data, dict) and masking_rules:
        return self._apply_masking_rules(data, masking_rules)

    return DATA_MASKING_STRING
119+
120+
def _apply_masking_rules(self, data: dict, masking_rules: dict) -> dict:
121+
return {
122+
key: self.erase(str(value), **masking_rules[key]) if key in masking_rules else str(value)
123+
for key, value in data.items()
124+
}
125+
126+
def _pattern_mask(self, data: str, pattern: str) -> str:
127+
return pattern[: len(data)] if len(pattern) >= len(data) else pattern
128+
129+
def _regex_mask(self, data: str, regex_pattern: str, mask_format: str) -> str:
130+
try:
131+
if regex_pattern not in _regex_cache:
132+
_regex_cache[regex_pattern] = re.compile(regex_pattern)
133+
return _regex_cache[regex_pattern].sub(mask_format, data)
134+
except re.error:
78135
return DATA_MASKING_STRING
79-
elif isinstance(data, (list, tuple, set)):
80-
return type(data)([DATA_MASKING_STRING] * len(data))
81-
return DATA_MASKING_STRING
136+
137+
def _custom_erase(self, data: str, **kwargs) -> str:
138+
if not data:
139+
return ""
140+
141+
# Use join with list comprehension instead of building list incrementally
142+
return "".join("*" if char not in PRESERVE_CHARS else char for char in data)

0 commit comments

Comments
 (0)