Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdp/processors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
SubIfASRSubstitution,
SubMakeLowercase,
SubRegex,
SearchRegex,
)
from sdp.processors.modify_manifest.data_to_dropbool import (
DropASRError,
Expand Down
46 changes: 46 additions & 0 deletions sdp/processors/modify_manifest/data_to_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,3 +577,49 @@ def finalize(self, metrics):
for word, count in total_counter_sorted.items():
logger.info(f"{word} {count}")
super().finalize(metrics)


class SearchRegex(BaseParallelProcessor):
    """Flags data entries whose text matches at least one regex pattern.

    For every entry the text stored under ``text_key`` is tested against each
    pattern with ``re.search``; the boolean outcome (``True`` if any pattern
    matched) is written back to the entry under ``output_key``.

    Args:
        search_patterns (list[str]): List of search patterns (regular
            expressions handed to ``re.search``).
        text_key (str): Key in the data entry containing the text to search.
        output_key (str): Key in the data entry to store the output value
            indicating if any pattern has been found.
        **kwargs: Forwarded unchanged to the parent class constructor.
    """

    def __init__(
        self,
        search_patterns: List[str],
        text_key: str = "text",
        output_key: str = "pattern_found",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.search_patterns = search_patterns
        self.text_key = text_key
        self.output_key = output_key

    def process_dataset_entry(self, data_entry) -> List:
        """Marks a single entry with whether any search pattern matches its text.

        Args:
            data_entry: Mapping holding the text under ``self.text_key``; it is
                modified in place by adding ``self.output_key``.

        Returns:
            A one-element list with a ``DataEntry`` wrapping the updated entry;
            its ``metrics`` value is the same boolean stored in the entry.

        Raises:
            KeyError: If ``self.text_key`` is missing from ``data_entry``.
        """
        text_in = data_entry[self.text_key]

        # any() stops at the first matching pattern; the per-pattern results
        # were never used, so only the combined flag is computed.
        pattern_found = any(re.search(pattern, text_in) for pattern in self.search_patterns)

        data_entry[self.output_key] = pattern_found

        return [DataEntry(data=data_entry, metrics=pattern_found)]

    def finalize(self, metrics):
        """Reports counts of how many data entries contained patterns.

        Args:
            metrics: Sized collection of boolean flags; presumably the
                ``metrics`` values emitted by ``process_dataset_entry``
                (TODO confirm against the base class).
        """
        # Summing booleans counts the True flags.
        matched = sum(metrics)
        # Use the module logger, as the other finalize() in this file does.
        logger.info(f"Samples amount which contain patterns: {matched}")
        logger.info(f"Samples amount which don't contain patterns: {len(metrics) - matched}")

        super().finalize(metrics)