diff --git a/docs/src/sdp/api.rst b/docs/src/sdp/api.rst
index 23fa95aa..9582c885 100644
--- a/docs/src/sdp/api.rst
+++ b/docs/src/sdp/api.rst
@@ -144,6 +144,9 @@ Data modifications
    :annotation:
    :noindex:
 
+.. autodata:: sdp.processors.SearchRegex
+   :annotation:
+
 .. autodata:: sdp.processors.SubMakeLowercase
    :annotation:
 
diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py
index 3dd04a4c..a04fa12c 100644
--- a/sdp/processors/__init__.py
+++ b/sdp/processors/__init__.py
@@ -64,6 +64,7 @@
     SubIfASRSubstitution,
     SubMakeLowercase,
     SubRegex,
+    SearchRegex,
     FfmpegConvert,
 )
 from sdp.processors.modify_manifest.data_to_dropbool import (
diff --git a/sdp/processors/modify_manifest/data_to_data.py b/sdp/processors/modify_manifest/data_to_data.py
index f2fb6011..cc2a3133 100644
--- a/sdp/processors/modify_manifest/data_to_data.py
+++ b/sdp/processors/modify_manifest/data_to_data.py
@@ -577,4 +577,44 @@ def finalize(self, metrics):
         total_counter_sorted = dict(sorted(total_counter.items(), key=lambda x: x[1], reverse=True))
         for word, count in total_counter_sorted.items():
             logger.info(f"{word} {count}")
-        super().finalize(metrics)
\ No newline at end of file
+
+        super().finalize(metrics)
+
+
+class SearchRegex(BaseParallelProcessor):
+    """Flags data entries whose text matches at least one regex pattern.
+
+    Args:
+        search_patterns (list[str]): List of search patterns, applied with ``re.search``.
+        text_key (str): Key in the data entry containing the text to search.
+        output_key (str): Key in the data entry to store the output value indicating if any pattern has been found.
+    """
+
+    def __init__(
+        self,
+        search_patterns: List[str],
+        text_key: str = "text",
+        output_key: str = "pattern_found",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.search_patterns = search_patterns
+        self.text_key = text_key
+        self.output_key = output_key
+
+    def process_dataset_entry(self, data_entry) -> List:
+        """Records in ``output_key`` whether any pattern matches the input text."""
+        text_in = data_entry[self.text_key]
+        # Only the overall flag is reported, so stop at the first matching pattern.
+        pattern_found = any(re.search(pattern, text_in) for pattern in self.search_patterns)
+        data_entry[self.output_key] = pattern_found
+
+        return [DataEntry(data=data_entry, metrics=pattern_found)]
+
+    def finalize(self, metrics):
+        """Reports counts of how many data entries contained patterns."""
+        num_found = sum(metrics)
+        logger.info(f"Samples amount which contain patterns: {num_found}")
+        logger.info(f"Samples amount which don't contain patterns: {len(metrics) - num_found}")
+
+        super().finalize(metrics)
diff --git a/tests/test_data_to_data.py b/tests/test_data_to_data.py
index 5bd75f47..260efdba 100644
--- a/tests/test_data_to_data.py
+++ b/tests/test_data_to_data.py
@@ -19,6 +19,7 @@
     SubIfASRSubstitution,
     SubMakeLowercase,
     SubRegex,
+    SearchRegex,
 )
 
 test_params_list = []
@@ -90,6 +91,23 @@
     ]
 )
 
+test_params_list.extend(
+    [
+        (
+            SearchRegex,
+            {"search_patterns": ["[^a-zA-Z\\s]+"]},
+            {"text": "Hola, bienvenido seas a este Canal de Ministerio Latino por Cristo."},
+            {"text": "Hola, bienvenido seas a este Canal de Ministerio Latino por Cristo.", "pattern_found": True},
+        ),
+        (
+            SearchRegex,
+            {"search_patterns": ["[^a-zA-Z\\s]+"]},
+            {"text": "Hola bienvenido seas a este Canal de Ministerio Latino por Cristo"},
+            {"text": "Hola bienvenido seas a este Canal de Ministerio Latino por Cristo", "pattern_found": False},
+        ),
+    ]
+)
+
 
 @pytest.mark.parametrize("test_class,class_kwargs,test_input,expected_output", test_params_list, ids=str)
 def test_data_to_data(test_class, class_kwargs, test_input, expected_output):