|
| 1 | +import re |
| 2 | +from typing import Callable, Dict, List, Optional |
| 3 | + |
| 4 | +from guardrails.logger import logger |
| 5 | +from guardrails.validator_base import ( |
| 6 | + FailResult, |
| 7 | + PassResult, |
| 8 | + ValidationResult, |
| 9 | + Validator, |
| 10 | + register_validator, |
| 11 | +) |
| 12 | + |
| 13 | + |
# nltk is an optional dependency: fall back to None when it is missing so the
# module still imports, and let the validator raise at call time instead.
try:
    import nltk  # type: ignore
except ImportError:
    nltk = None  # type: ignore
else:
    # nltk is importable: make sure the "punkt" sentence tokenizer data is
    # present locally, downloading it when the lookup fails.
    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt")

# spacy is optional as well; None marks it as unavailable.
try:
    import spacy
except ImportError:
    spacy = None
| 29 | + |
@register_validator(name="competitor-check", data_type="string")
class CompetitorCheck(Validator):
    """Validates that LLM-generated text is not naming any competitors from a
    given list.

    In order to use this validator you need to provide an extensive list of the
    competitors you want to avoid naming including all common variations.

    A sentence is flagged only when a competitor name occurs in it as a whole
    token (case-insensitively) AND that name also occurs inside one of the
    entities that the spaCy pipeline extracts from the same sentence.

    Args:
        competitors (List[str]): List of competitors you want to avoid naming
        on_fail (Callable, optional): Forwarded unchanged to the base
            validator constructor.
    """

    def __init__(
        self,
        competitors: List[str],
        on_fail: Optional[Callable] = None,
    ):
        """Store the competitor list and load the spaCy pipeline.

        Args:
            competitors (List[str]): Competitor names to look for.
            on_fail (Callable, optional): Forwarded to the base constructor.

        Raises:
            ImportError: If spacy could not be imported.
        """
        super().__init__(competitors=competitors, on_fail=on_fail)
        self._competitors = competitors
        model = "en_core_web_trf"
        if spacy is None:
            raise ImportError(
                "You must install spacy in order to use the CompetitorCheck validator."
            )

        # Fetch the model on first use when it is not installed as a package.
        if not spacy.util.is_package(model):
            logger.info(
                f"Spacy model {model} not installed. "
                "Download should start now and take a few minutes."
            )
            spacy.cli.download(model)  # type: ignore

        self.nlp = spacy.load(model)

    @staticmethod
    def _mentions(name: str, text: str) -> bool:
        """Tell whether `name` occurs in `text` as a whole token, ignoring case.

        Lookarounds are used instead of ``\\b`` so that names which start or
        end with a non-word character (e.g. "Yahoo!") can still match; for
        names made of word characters the two forms are equivalent.

        Args:
            name (str): The name to look for. An empty name never matches.
            text (str): The text to search.

        Returns:
            bool: True when the name is found, False otherwise.
        """
        if not name:
            # An empty pattern would match almost any text.
            return False
        pattern = rf"(?<!\w){re.escape(name)}(?!\w)"
        return re.search(pattern, text, flags=re.IGNORECASE) is not None

    def exact_match(self, text: str, competitors: List[str]) -> List[str]:
        """Performs exact match to find competitors from a list in a given
        text.

        Matching is case-insensitive and restricted to whole tokens.

        Args:
            text (str): The text to search for competitors.
            competitors (list): A list of competitor entities to match.

        Returns:
            list: The competitors found in the text, in the order they appear
            in `competitors`.
        """
        return [name for name in competitors if self._mentions(name, text)]

    def perform_ner(self, text: str, nlp) -> List[str]:
        """Performs named entity recognition on text using a provided NLP
        model.

        Args:
            text (str): The text to perform named entity recognition on.
            nlp: The NLP model to use for entity recognition. It is called on
                the text and the result's `ents` are read.

        Returns:
            entities: A list with the text of each entity found.
        """
        return [ent.text for ent in nlp(text).ents]

    def is_entity_in_list(self, entities: List[str], competitors: List[str]) -> List:
        """Checks if any entity from a list is present in a given list of
        competitors.

        A competitor counts as found when its name occurs, as a whole token
        and ignoring case, inside at least one entity.

        Args:
            entities (list): A list of entities to check
            competitors (list): A list of competitor names to match

        Returns:
            List: List of found competitors, without duplicates, in first-seen
            order.
        """
        found_competitors: List[str] = []
        for entity in entities:
            for name in competitors:
                # Report each competitor once even if several entities name it.
                if name not in found_competitors and self._mentions(name, entity):
                    found_competitors.append(name)
        return found_competitors

    def validate(self, value: str, metadata: Optional[Dict] = None) -> ValidationResult:
        """Checks a text to find competitors' names in it.

        While running, collect the competitors named and generate a fixed output
        filtering out all flagged sentences.

        Args:
            value (str): The value to be validated.
            metadata (Dict, optional): Additional metadata. Not read by this
                validator. Defaults to None.

        Returns:
            ValidationResult: A FailResult listing the competitors found, with
            the unflagged sentences joined by single spaces as the fix value,
            or a PassResult when no sentence is flagged.

        Raises:
            ImportError: If nltk could not be imported.
        """
        if nltk is None:
            raise ImportError(
                "`nltk` library is required for `competitor-check` validator. "
                "Please install it with `poetry add nltk`."
            )

        filtered_sentences = []
        list_of_competitors_found: List[str] = []

        for sentence in nltk.sent_tokenize(value):
            # Cheap textual screen first; NER only runs on candidate sentences.
            candidates = self.exact_match(sentence, self._competitors)
            if not candidates:
                filtered_sentences.append(sentence)
                continue

            ner_entities = self.perform_ner(sentence, self.nlp)
            found_competitors = self.is_entity_in_list(ner_entities, candidates)
            if not found_competitors:
                # Name appears in the text but not inside any recognized entity.
                filtered_sentences.append(sentence)
                continue

            # Keep one flat, duplicate-free list for the error message.
            for name in found_competitors:
                if name not in list_of_competitors_found:
                    list_of_competitors_found.append(name)
            logger.debug(f"Found: {found_competitors} named in '{sentence}'")

        if not list_of_competitors_found:
            return PassResult()

        return FailResult(
            error_message=(
                f"Found the following competitors: {list_of_competitors_found}. "
                "Please avoid naming those competitors next time"
            ),
            fix_value=" ".join(filtered_sentences),
        )
0 commit comments