|
| 1 | +"""PII Purifier using Mellea Framework.""" |
| 2 | + |
| 3 | +import spacy |
| 4 | + |
| 5 | +from cli.serve.models import ChatMessage |
| 6 | +import mellea |
| 7 | +from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B |
| 8 | +from mellea.stdlib.base import ModelOutputThunk |
| 9 | +from mellea.stdlib.requirement import req, simple_validate |
| 10 | +from mellea.stdlib.sampling import RejectionSamplingStrategy |
| 11 | +from mellea.stdlib.sampling.types import SamplingResult |
| 12 | + |
| 13 | + |
# Process-wide cache for the spaCy pipeline; filled in on first use.
_NLP_PIPELINE = None


def _get_nlp():
    """Return the shared ``en_core_web_sm`` pipeline, loading it on first use."""
    global _NLP_PIPELINE
    if _NLP_PIPELINE is None:
        # Building the pipeline is costly, so do it once per process instead
        # of once per call.
        _NLP_PIPELINE = spacy.load("en_core_web_sm")
    return _NLP_PIPELINE


def has_potential_pii(text: str) -> bool:
    """Heuristically report whether ``text`` appears to contain PII.

    The text is run through the spaCy ``en_core_web_sm`` pipeline and is
    flagged when any of the following holds:

    * a named entity is labelled ``PERSON``, ``GPE``, ``LOC`` or ``ORG``;
    * a token contains both ``"@"`` and ``"."`` (email-like);
    * a token contains seven or more digit characters (phone-like).

    Args:
        text: The text to inspect.

    Returns:
        ``True`` if any check matches, ``False`` otherwise.
    """
    doc = _get_nlp()(text)

    # Entity labels treated as PII: people, places and organisations.
    pii_labels = {"PERSON", "GPE", "LOC", "ORG"}
    if any(ent.label_ in pii_labels for ent in doc.ents):
        return True

    # Token-level checks for email-like and phone-like strings.
    for token in doc:
        token_text = token.text
        if "@" in token_text and "." in token_text:
            return True
        if sum(ch.isdigit() for ch in token_text) >= 7:
            return True

    return False
| 36 | + |
| 37 | + |
def pii_remove_validate(
    m: mellea.MelleaSession,
    text: str,
    requirements: list[str] | None = None,
    loop_budget: int = 3,
    model_options: None | dict = None,
) -> ModelOutputThunk | SamplingResult | str:
    """Instruct the model to replace PII in ``text`` with ``XXX`` and validate it.

    Three built-in requirements are always sent; the last one is checked
    programmatically with ``has_potential_pii``. Caller-supplied requirements
    are appended after the built-in ones.

    Args:
        m: Session on which ``instruct`` is called.
        text: The text to scrub; it is embedded verbatim in the instruction.
        requirements: Optional extra requirements; ``None`` or an empty list
            adds nothing.
        loop_budget: Passed to ``RejectionSamplingStrategy`` as ``loop_budget``.
        model_options: Forwarded unchanged to ``instruct``.

    Returns:
        The object returned by ``instruct`` when its ``success`` attribute is
        truthy, otherwise the string ``"The Validation Failed"``.
    """
    extra_requirements = requirements or []

    def _is_pii_free(output):
        # Validation passes only when the heuristic finds nothing.
        return not has_potential_pii(output)

    instruction = (
        f"Remove all personally identifiable information from the following text "
        f"and replace it with XXX:\n\n{text}"
    )
    all_requirements = [
        req(
            "Replace all names,email addresses, phone numbers, and addresses with XXX"
        ),
        req("Preserve non-PII content unchanged"),
        req(
            "Output must not contain PII",
            validation_fn=simple_validate(_is_pii_free),
        ),
        *extra_requirements,
    ]

    sampling = m.instruct(
        instruction,
        requirements=all_requirements,
        strategy=RejectionSamplingStrategy(loop_budget=loop_budget),
        return_sampling_results=True,
        model_options=model_options,
    )

    if not sampling.success:
        return "The Validation Failed"
    return sampling
| 72 | + |
| 73 | + |
# Module-level session, created at import time and reused by every `serve` call.
session = mellea.start_session(model_id=IBM_GRANITE_4_MICRO_3B)


def serve(
    input: list[ChatMessage],
    requirements: list[str] | None = None,
    model_options: None | dict = None,
) -> ModelOutputThunk | SamplingResult | str:
    """Scrub PII from the content of the last message in ``input``.

    Only ``input[-1].content`` is used; earlier messages are ignored.

    Args:
        input: Chat messages; must contain at least one element.
        requirements: Optional extra requirements, forwarded to
            ``pii_remove_validate``.
        model_options: Forwarded unchanged to ``pii_remove_validate``.

    Returns:
        Whatever ``pii_remove_validate`` returns for the last message.
    """
    latest_text = input[-1].content
    return pii_remove_validate(
        session,
        latest_text,
        requirements=requirements,
        model_options=model_options,
    )
0 commit comments