3 changes: 3 additions & 0 deletions changelog/optimized-pattern-aggregator.changed.md
@@ -0,0 +1,3 @@
- Improved `PatternPairAggregator` performance by using incremental delimiter detection.
- Added `MatchAction` enum to control pattern matching behavior (`REMOVE`, `KEEP`, `AGGREGATE`).
- Deprecated `add_pattern_pair` in favor of `add_pattern` with `MatchAction`.
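A quick migration sketch for the deprecation above. The parameter names are taken from the diff body below; the deprecated `add_pattern_pair` call in the comment is illustrative, not necessarily its exact old signature:

```python
from pipecat.utils.text.pattern_pair_aggregator import (
    MatchAction,
    PatternPairAggregator,
)

aggregator = PatternPairAggregator()

# Before (deprecated): add_pattern_pair registered a start/end pair with a
# flag controlling whether the match was stripped from the output.
# aggregator.add_pattern_pair("voice", "<voice>", "</voice>", ...)

# After: add_pattern takes an explicit MatchAction.
aggregator.add_pattern(
    type="voice",
    start_pattern="<voice>",
    end_pattern="</voice>",
    action=MatchAction.AGGREGATE,  # or REMOVE / KEEP
)
```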
289 changes: 145 additions & 144 deletions src/pipecat/utils/text/pattern_pair_aggregator.py
@@ -11,7 +11,7 @@
support for custom handlers and configurable actions for when a pattern is found.
"""

import re
from dataclasses import dataclass
from enum import Enum
from typing import AsyncIterator, Awaitable, Callable, List, Optional, Tuple

@@ -27,8 +27,8 @@ class MatchAction(Enum):
Parameters:
REMOVE: The text along with its delimiters will be removed from the streaming text.
Sentence aggregation will continue on as if this text did not exist.
KEEP: The delimiters will be removed, but the content between them will be kept.
Sentence aggregation will continue on with the internal text included.
KEEP: The matched pattern (including delimiters) will be kept in the text.
Sentence aggregation will continue with the full matched text included.
AGGREGATE: The delimiters will be removed and the content between will be treated
as a separate aggregation. Any text before the start of the pattern will be
returned early, whether or not a complete sentence was found. Then the pattern
@@ -72,6 +72,28 @@ def __str__(self) -> str:
return f"PatternMatch(type={self.type}, text={self.text}, full_match={self.full_match})"


@dataclass(frozen=True)
class _PatternSpec:
type: str
start: str
end: str
action: MatchAction

@property
def start_len(self) -> int:
return len(self.start)

@property
def end_len(self) -> int:
return len(self.end)


@dataclass
class _OpenPattern:
spec: _PatternSpec
start_idx: int


class PatternPairAggregator(SimpleTextAggregator):
"""Aggregator that identifies and processes content between pattern pairs.

@@ -99,8 +121,11 @@ def __init__(self, **kwargs):
"""
super().__init__()
self._patterns = {}
self._specs = {}
self._handlers = {}
self._last_processed_position = 0 # Track where we last checked for complete patterns
self._open = []
self._start_by_last = {}
self._end_by_last = {}

@property
def text(self) -> Aggregation:
@@ -109,14 +134,10 @@ def text(self) -> Aggregation:
Returns:
The text that has been accumulated in the buffer.
"""
pattern_start = self._match_start_of_pattern(self._text)
stripped_text = self._text.strip()
type = (
pattern_start[1].get("type", AggregationType.SENTENCE)
if pattern_start
else AggregationType.SENTENCE
)
return Aggregation(text=stripped_text, type=type)
if self._open:
return Aggregation(text=stripped_text, type=self._open[-1].spec.type)
return Aggregation(text=stripped_text, type=AggregationType.SENTENCE)
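
A hedged sketch of what the reworked `text` property implies for callers: while a start delimiter is open, the buffered aggregation reports the open pattern's type instead of `SENTENCE`. The expected output is inferred from this diff, not verified:

```python
import asyncio

from pipecat.utils.text.pattern_pair_aggregator import (
    MatchAction,
    PatternPairAggregator,
)

async def demo() -> None:
    aggr = PatternPairAggregator().add_pattern(
        type="code",
        start_pattern="<code>",
        end_pattern="</code>",
        action=MatchAction.AGGREGATE,
    )
    # Feed an unterminated pattern; nothing should yield yet.
    async for _ in aggr.aggregate("<code>partial"):
        pass
    print(aggr.text.type)  # expected: "code" while the pattern is unclosed

asyncio.run(demo())
```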

def add_pattern(
self,
@@ -150,12 +171,36 @@ def add_pattern(
raise ValueError(
f"The aggregation type '{type}' is reserved for default behavior and can not be used for custom patterns."
)
if not start_pattern or not end_pattern:
raise ValueError("start_pattern and end_pattern must be non-empty strings.")

old = self._specs.get(type)
if old is not None:
try:
self._start_by_last.get(old.start[-1], []).remove(old)
except ValueError:
pass
try:
self._end_by_last.get(old.end[-1], []).remove(old)
except ValueError:
pass

spec = _PatternSpec(type=type, start=start_pattern, end=end_pattern, action=action)

self._patterns[type] = {
"start": start_pattern,
"end": end_pattern,
"type": type,
"action": action,
}
self._specs[type] = spec

self._start_by_last.setdefault(start_pattern[-1], []).append(spec)
self._end_by_last.setdefault(end_pattern[-1], []).append(spec)

self._start_by_last[start_pattern[-1]].sort(key=lambda s: s.start_len, reverse=True)
self._end_by_last[end_pattern[-1]].sort(key=lambda s: s.end_len, reverse=True)

return self
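
For reviewers, a minimal standalone model of the last-character bucketing built above (illustrative only; it mirrors `_start_by_last`, not pipecat code). Bucketing specs by the final character of each delimiter means the hot path only inspects delimiters that could possibly end at the character just appended, and sorting each bucket longest-first lets a longer delimiter (e.g. `<<<`) win over a shorter one sharing its last character (`<`):

```python
from collections import defaultdict

start_by_last: dict[str, list[str]] = defaultdict(list)

for delim in ("<", "<<<", "[["):
    start_by_last[delim[-1]].append(delim)

# Longest-first, so "<<<" is tested before "<" in the "<" bucket.
for bucket in start_by_last.values():
    bucket.sort(key=len, reverse=True)

def ends_with_delimiter(buffer: str) -> str | None:
    """Return the longest registered delimiter the buffer currently ends with."""
    if not buffer:
        return None
    for delim in start_by_last.get(buffer[-1], []):
        if buffer.endswith(delim):
            return delim
    return None

assert ends_with_delimiter("text <<<") == "<<<"  # not "<"
assert ends_with_delimiter("note [[") == "[["
assert ends_with_delimiter("plain text") is None
```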

def add_pattern_pair(
@@ -217,100 +262,63 @@ def on_pattern_match(
self._handlers[type] = handler
return self

async def _process_complete_patterns(
self, text: str, last_processed_position: int = 0
) -> Tuple[List[PatternMatch], str]:
"""Process newly complete pattern pairs in the text.

Searches for pattern pairs that have been completed since last_processed_position,
calls the appropriate handlers, and optionally removes the matches.

Args:
text: The text to process for pattern matches.
last_processed_position: The position in text that was already processed.
Only patterns that end at or after this position will be processed.
def _push_open_if_start_delimiter(
self, ignore_spec: Optional[_PatternSpec] = None
) -> Optional[_OpenPattern]:
if not self._text:
return None
last = self._text[-1]
for spec in self._start_by_last.get(last, []):
if ignore_spec and spec is ignore_spec:
continue
if len(self._text) >= spec.start_len and self._text.endswith(spec.start):
start_idx = len(self._text) - spec.start_len
op = _OpenPattern(spec=spec, start_idx=start_idx)
self._open.append(op)
return op
return None

Returns:
Tuple of (all_matches, processed_text) where:
async def _close_one_if_end_delimiter(self) -> Optional[Tuple[_PatternSpec, PatternMatch]]:
if not self._text or not self._open:
return None

- all_matches is a list of all pattern matches found. Note: There really should only ever be 1.
- processed_text is the text after processing patterns. If no patterns are found, it will be the same as input text.
"""
all_matches = []
processed_text = text

for type, pattern_info in self._patterns.items():
# Escape special regex characters in the patterns
start = re.escape(pattern_info["start"])
end = re.escape(pattern_info["end"])
action = pattern_info["action"]

# Create regex to match from start pattern to end pattern
# The .*? is non-greedy to handle nested patterns
regex = f"{start}(.*?){end}"

# Find all matches
match_iter = re.finditer(regex, processed_text, re.DOTALL)
matches = list(match_iter) # Convert to list for safe iteration

for match in matches:
content = match.group(1) # Content between patterns
full_match = match.group(0) # Full match including patterns

# Create pattern match object
pattern_match = PatternMatch(
content=content.strip(), type=type, full_match=full_match
)

# Check if this pattern was already processed
already_processed = match.end() <= last_processed_position
last = self._text[-1]
for spec in self._end_by_last.get(last, []):
if len(self._text) < spec.end_len or not self._text.endswith(spec.end):
continue

# Only call handler for newly completed patterns
if not already_processed and type in self._handlers:
try:
await self._handlers[type](pattern_match)
except Exception as e:
logger.error(f"Error in pattern handler for {type}: {e}")
open_idx = None
for i in range(len(self._open) - 1, -1, -1):
if self._open[i].spec.type == spec.type:
open_idx = i
break
if open_idx is None:
continue

# Handle pattern based on action
if action == MatchAction.REMOVE:
# Remove patterns are only removed once (when newly completed)
if not already_processed:
processed_text = processed_text.replace(full_match, "", 1)
else:
# KEEP/AGGREGATE patterns stay in all_matches
all_matches.append(pattern_match)
start_idx = self._open[open_idx].start_idx
if start_idx < 0 or start_idx > len(self._text) - spec.end_len:
continue
if not self._text.startswith(spec.start, start_idx):
continue

return all_matches, processed_text
end_idx = len(self._text)
full_match = self._text[start_idx:end_idx]
content = full_match[spec.start_len : -spec.end_len]
pm = PatternMatch(content=content.strip(), type=spec.type, full_match=full_match)

def _match_start_of_pattern(self, text: str) -> Optional[Tuple[int, dict]]:
"""Check if text contains incomplete pattern pairs.
del self._open[open_idx:]

Determines whether the text contains any start patterns without
matching end patterns, which would indicate incomplete content.
handler = self._handlers.get(spec.type)
if handler is not None:
try:
await handler(pm)
except Exception as e:
logger.error(f"Error in pattern handler for {spec.type}: {e}")

Args:
text: The text to check for incomplete patterns.
if spec.action in (MatchAction.REMOVE, MatchAction.AGGREGATE):
self._text = self._text[:start_idx]

Returns:
A tuple of (start_index, pattern_info) if an incomplete pattern is found,
or None if no patterns are found or all patterns are complete.
"""
for type, pattern_info in self._patterns.items():
start = pattern_info["start"]
end = pattern_info["end"]

# Count occurrences
start_count = text.count(start)
end_count = text.count(end)

# If there are more starts than ends, we have incomplete patterns
# Again, this is written generically but there only ever should
# be one pattern active at a time, so the counts should be 0 or 1.
# Which is why we base the return on the first found.
if start_count > end_count:
start_index = text.find(start)
return [start_index, pattern_info]
return spec, pm

return None
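
A toy, runnable model of the close logic above (not pipecat code): when an end delimiter arrives, the stack is scanned from the top for the newest open entry of that pattern type, and `del stack[open_idx:]` drops any inner opens that never closed, since their start delimiters were consumed as content of the completed match:

```python
# (type, start_idx) entries, innermost last -- mirrors self._open.
open_stack: list[tuple[str, int]] = [("code", 0), ("note", 9)]

def close(stack: list[tuple[str, int]], pattern_type: str) -> tuple[str, int] | None:
    """Pop the newest open entry of pattern_type, discarding anything above it."""
    for i in range(len(stack) - 1, -1, -1):
        if stack[i][0] == pattern_type:
            closed = stack[i]
            del stack[i:]  # dangling inner opens are dropped with the match
            return closed
    return None

assert close(open_stack, "code") == ("code", 0)
assert open_stack == []  # the unclosed "note" open was discarded too
```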

@@ -327,58 +335,53 @@ async def aggregate(self, text: str) -> AsyncIterator[PatternMatch]:
Yields:
PatternMatch objects as patterns complete or sentences are detected.
"""
# Process text character by character
if not self._patterns and not self._open:
async for aggr in super().aggregate(text):
yield PatternMatch(content=aggr.text, type=aggr.type, full_match=aggr.text)
return

for char in text:
self._text += char
yielded_aggregate = False

# Process any newly complete patterns in the buffer
# Only patterns that complete after _last_processed_position will trigger handlers
patterns, processed_text = await self._process_complete_patterns(
self._text, self._last_processed_position
)
ignore_spec = None
while True:
closed = await self._close_one_if_end_delimiter()
if closed is None:
break
spec, pm = closed

# Update the last processed position to prevent re-processing patterns
# This tracks where in the buffer we've already called handlers, so we
# only trigger handlers once when a pattern completes
self._last_processed_position = len(self._text)

self._text = processed_text

if len(patterns) > 0:
if len(patterns) > 1:
logger.warning(
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
)
# If the pattern found is set to be aggregated, return it
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
if action == MatchAction.AGGREGATE:
self._text = ""
yield patterns[0]
continue

# Check if we have incomplete patterns
pattern_start = self._match_start_of_pattern(self._text)
if pattern_start is not None:
# If the start pattern is at the beginning or should not be separately aggregated, continue
if (
pattern_start[0] == 0
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
):
continue
# For AGGREGATE patterns: yield any text before the pattern starts
# This ensures text doesn't get stuck in the buffer waiting for sentence
# boundaries when a pattern begins (e.g., "Here is code <code>..." yields "Here is code")
result = self._text[: pattern_start[0]]
self._text = self._text[pattern_start[0] :]
yield PatternMatch(
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
)
if spec.action == MatchAction.KEEP:
ignore_spec = spec

if spec.action == MatchAction.AGGREGATE:
yield pm
yielded_aggregate = True
break

if yielded_aggregate:
continue

was_open = bool(self._open)
opened = self._push_open_if_start_delimiter(ignore_spec=ignore_spec)

if (
opened is not None
and not was_open
and opened.spec.action == MatchAction.AGGREGATE
and opened.start_idx > 0
):
prefix = self._text[: opened.start_idx]
self._text = self._text[opened.start_idx :]
opened.start_idx = 0
yield PatternMatch(content=prefix.strip(), type=AggregationType.SENTENCE, full_match=prefix)
continue

if self._open:
continue

# Use parent's lookahead logic for sentence detection
aggregation = await super()._check_sentence_with_lookahead(char)
if aggregation:
# Convert to PatternMatch for consistency with return type
yield PatternMatch(
content=aggregation.text, type=aggregation.type, full_match=aggregation.text
)
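
End-to-end, here is a hedged sketch of driving the new incremental path with an arbitrarily chunked stream, using the `<code>` example from the comments above. The exact yields are inferred from the AGGREGATE semantics in this diff; the trailing sentence may only flush on a later boundary or an explicit flush:

```python
import asyncio

from pipecat.utils.text.pattern_pair_aggregator import (
    MatchAction,
    PatternPairAggregator,
)

async def main() -> None:
    aggregator = PatternPairAggregator()
    aggregator.add_pattern(
        type="code",
        start_pattern="<code>",
        end_pattern="</code>",
        action=MatchAction.AGGREGATE,
    )

    # Chunks split mid-delimiter, as an LLM token stream might arrive.
    for chunk in ["Here is code <co", "de>print(42)</co", "de> Done."]:
        async for match in aggregator.aggregate(chunk):
            print(match.type, repr(match.content))

# Expected, roughly:
#   sentence 'Here is code'   <- prefix flushed early when <code> opens
#   code 'print(42)'          <- delimiters stripped, its own aggregation
#   (the trailing "Done." flushes once sentence detection fires)
asyncio.run(main())
```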
@@ -390,8 +393,7 @@ async def handle_interruption(self):
to reset the state and discard any partially aggregated text.
"""
await super().handle_interruption()
self._last_processed_position = 0
# Pattern and handler state persists across interruptions
self._open.clear()

async def reset(self):
"""Clear the internally aggregated text.
@@ -400,5 +402,4 @@ async def reset(self):
buffered text and clearing pattern tracking state.
"""
await super().reset()
self._last_processed_position = 0
# Pattern and handler state persists across resets
self._open.clear()