
Commit 9e50cdd: large cleanup
1 parent 5f001e0
File tree: 4 files changed (+33, -250 lines)
Lines changed: 13 additions & 184 deletions
@@ -1,15 +1,10 @@
 import abc
 import logging
-import time
 import uuid
-from typing import Optional, Dict, Any, Tuple, List, Union
+from typing import Optional, Dict, Any, Union
 from getstream.video.rtc.track_util import PcmData
 
 from ..edge.types import Participant
-from vision_agents.core.events import (
-    PluginInitializedEvent,
-    PluginClosedEvent,
-)
 from vision_agents.core.events.manager import EventManager
 from . import events
 
@@ -20,84 +15,25 @@ class STT(abc.ABC):
     """
     Abstract base class for Speech-to-Text implementations.
 
-    This class provides a standardized interface for STT implementations with consistent
-    event emission patterns and error handling.
+    Subclasses implement this and have to call
+    - _emit_partial_transcript_event
+    - _emit_transcript_event
+    - _emit_error_event for temporary errors
 
-    Events:
-        - transcript: Emitted when a complete transcript is available.
-          Args: text (str), user_metadata (dict), metadata (dict)
-        - partial_transcript: Emitted when a partial transcript is available.
-          Args: text (str), user_metadata (dict), metadata (dict)
-        - error: Emitted when an error occurs during transcription.
-          Args: error (Exception)
-
-    Standard Error Handling:
-        - All implementations should catch exceptions in _process_audio_impl and emit error events
-        - Use _emit_error_event() helper for consistent error emission
-        - Log errors with appropriate context using the logger
-
-    Standard Event Emission:
-        - Use _emit_transcript_event() and _emit_partial_transcript_event() helpers
-        - Include processing time and audio duration in metadata when available
-        - Maintain consistent metadata structure across implementations
+    process_audio is currently called every 20ms. The integration with turn keeping could be improved
     """
+    closed: bool = False
 
     def __init__(
         self,
-        sample_rate: int = 16000,
-        *,
         provider_name: Optional[str] = None,
     ):
-        """
-        Initialize the STT service.
-
-        Args:
-            sample_rate: The sample rate of the audio to process, in Hz.
-            provider_name: Name of the STT provider (e.g., "deepgram", "moonshine")
-        """
-
-        self._track = None
-        self.sample_rate = sample_rate
-        self._is_closed = False
         self.session_id = str(uuid.uuid4())
         self.provider_name = provider_name or self.__class__.__name__
+
         self.events = EventManager()
         self.events.register_events_from_module(events, ignore_not_compatible=True)
 
-        self.events.send(PluginInitializedEvent(
-            session_id=self.session_id,
-            plugin_name=self.provider_name,
-            plugin_type="STT",
-            provider=self.provider_name,
-            configuration={"sample_rate": sample_rate},
-        ))
-
-    def _validate_pcm_data(self, pcm_data: PcmData) -> bool:
-        """
-        Validate PCM data input for processing.
-
-        Args:
-            pcm_data: The PCM audio data to validate.
-
-        Returns:
-            True if the data is valid, False otherwise.
-        """
-
-        if not hasattr(pcm_data, "samples") or pcm_data.samples is None:
-            logger.warning("PCM data has no samples")
-            return False
-
-        if not hasattr(pcm_data, "sample_rate") or pcm_data.sample_rate <= 0:
-            logger.warning("PCM data has invalid sample rate")
-            return False
-
-        # Check if samples are empty
-        if hasattr(pcm_data.samples, "__len__") and len(pcm_data.samples) == 0:
-            logger.debug("Received empty audio samples")
-            return False
-
-        return True
-
     def _emit_transcript_event(
         self,
         text: str,
@@ -159,12 +95,8 @@ def _emit_error_event(
         user_metadata: Optional[Union[Dict[str, Any], Participant]] = None,
     ):
         """
-        Emit an error event with structured data.
-
-        Args:
-            error: The exception that occurred.
-            context: Additional context about where the error occurred.
-            user_metadata: User-specific metadata.
+        Emit an error event. Note this should only be emitted for temporary errors.
+        Permanent errors due to config etc should be directly raised
         """
         self.events.send(events.STTErrorEvent(
             session_id=self.session_id,
@@ -176,114 +108,11 @@ def _emit_error_event(
             is_recoverable=not isinstance(error, (SystemExit, KeyboardInterrupt)),
         ))
 
+    @abc.abstractmethod
     async def process_audio(
-        self, pcm_data: PcmData, participant: Optional[Participant] = None
+        self, pcm_data: PcmData, participant: Optional[Participant] = None,
     ):
-        """
-        Process audio data for transcription and emit appropriate events.
-
-        Args:
-            pcm_data: The PCM audio data to process.
-            user_metadata: Additional metadata about the user or session.
-        """
-        if self._is_closed:
-            logger.debug("Ignoring audio processing request - STT is closed")
-            return
-
-        # Validate input data
-        if not self._validate_pcm_data(pcm_data):
-            logger.warning("Invalid PCM data received, skipping processing")
-            return
-
-        try:
-            # Process the audio data using the implementation-specific method
-            audio_duration_ms = (
-                pcm_data.duration * 1000 if hasattr(pcm_data, "duration") else None
-            )
-            logger.debug(
-                "Processing audio chunk",
-                extra={
-                    "duration_ms": audio_duration_ms,
-                    "has_user_metadata": participant is not None,
-                },
-            )
-
-            start_time = time.time()
-            results = await self._process_audio_impl(pcm_data, participant)
-            processing_time = time.time() - start_time
-
-            # If no results were returned, just return
-            if not results:
-                logger.debug(
-                    "No speech detected in audio",
-                    extra={
-                        "processing_time_ms": processing_time * 1000,
-                        "audio_duration_ms": audio_duration_ms,
-                    },
-                )
-                return
-
-            # Process each result and emit the appropriate event
-            for is_final, text, metadata in results:
-                # Ensure metadata includes processing time if not already present
-                if "processing_time_ms" not in metadata:
-                    metadata["processing_time_ms"] = processing_time * 1000
-
-                if is_final:
-                    self._emit_transcript_event(text, participant, metadata)
-                else:
-                    self._emit_partial_transcript_event(text, participant, metadata)
-
-        except Exception as e:
-            # Emit any errors that occur during processing
-            self._emit_error_event(e, "audio processing", participant)
-
-    @abc.abstractmethod
-    async def _process_audio_impl(
-        self, pcm_data: PcmData, user_metadata: Optional[Union[Dict[str, Any], Participant]] = None
-    ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]:
-        """
-        Implementation-specific method to process audio data.
-
-        This method must be implemented by all STT providers and should handle the core
-        transcription logic. The base class handles event emission and error handling.
-
-        Args:
-            pcm_data: The PCM audio data to process. Guaranteed to be valid by base class.
-            user_metadata: Additional metadata about the user or session.
-
-        Returns:
-            optional list[tuple[bool, str, dict]] | None
-            • synchronous providers: a list of results.
-            • asynchronous providers: None (they emit events themselves).
-
-        Notes:
-            Implementations must not both emit events and return non-empty results,
-            or duplicate events will be produced.
-            Exceptions should bubble up; process_audio() will catch them
-            and emit a single "error" event.
-        """
         pass
 
-    @abc.abstractmethod
     async def close(self):
-        """
-        Close the STT service and release any resources.
-
-        Implementations should:
-        - Set self._is_closed = True
-        - Clean up any background tasks or connections
-        - Release any allocated resources
-        - Log the closure appropriately
-        """
-        if not self._is_closed:
-            self._is_closed = True
-
-            # Emit closure event
-            self.events.send(PluginClosedEvent(
-                session_id=self.session_id,
-                plugin_name=self.provider_name,
-                plugin_type="STT",
-                provider=self.provider_name,
-                cleanup_successful=True,
-            ))
+        self.closed = True
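The net effect in this file: the base class drops validation, timing, and plugin lifecycle events, leaving a thin contract where subclasses own process_audio and emit results through the remaining helpers. Below is a minimal sketch of a conforming subclass, assuming the helper signatures visible in the removed code (_emit_transcript_event(text, participant, metadata) and _emit_error_event(error, context, participant)); the import path for STT and the EchoSTT provider itself are hypothetical, purely illustrative:

from typing import Optional

from getstream.video.rtc.track_util import PcmData
from vision_agents.core.edge.types import Participant
from vision_agents.core.stt.stt import STT  # import path assumed


class EchoSTT(STT):
    """Hypothetical provider, only to illustrate the new contract."""

    def __init__(self):
        super().__init__(provider_name="echo")

    async def process_audio(
        self, pcm_data: PcmData, participant: Optional[Participant] = None,
    ):
        # The base class no longer validates or times anything; providers do it themselves.
        if self.closed:
            return
        try:
            # A real provider would forward pcm_data.samples to its backend here.
            text = f"heard {len(pcm_data.samples)} samples"
            # Partial first, then final: both helpers come from the base class.
            self._emit_partial_transcript_event(text, participant, {})
            self._emit_transcript_event(text, participant, {})
        except ConnectionError as e:
            # Temporary failure: emit an error event. Permanent (config) errors
            # should be raised directly, per the new docstring.
            self._emit_error_event(e, "audio processing", participant)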
File renamed without changes.

plugins/deepgram/vision_agents/plugins/deepgram/stt.py

Lines changed: 13 additions & 35 deletions
@@ -3,7 +3,7 @@
 import logging
 import os
 import time
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import numpy as np
 import websockets
@@ -23,8 +23,7 @@
 
 from .utils import generate_silence
 
-if TYPE_CHECKING:
-    from vision_agents.core.edge.types import Participant
+from vision_agents.core.edge.types import Participant
 
 logger = logging.getLogger(__name__)
 
@@ -50,7 +49,6 @@ def __init__(
         self,
         api_key: Optional[str] = None,
         options: Optional[dict] = None,
-        sample_rate: int = 48000,
         language: str = "en-US",
         interim_results: bool = True,
         client: Optional[AsyncDeepgramClient] = None,
@@ -70,7 +68,7 @@ def __init__(
             connection_timeout: Time to wait for the Deepgram connection to be established.
 
         """
-        super().__init__(sample_rate=sample_rate)
+        super().__init__(provider_name="deepgram")
 
         # If no API key was provided, check for DEEPGRAM_API_KEY in environment
         if api_key is None:
@@ -86,12 +84,13 @@ def __init__(
             client if client is not None else AsyncDeepgramClient(api_key=api_key)
         )
         self.dg_connection: Optional[AsyncV1SocketClient] = None
+        self.sample_rate = 48000
 
         self.options = options or {
             "model": "nova-2",
             "language": language,
             "encoding": "linear16",
-            "sample_rate": sample_rate,
+            "sample_rate": self.sample_rate,
             "channels": 1,
             "interim_results": interim_results,
         }
@@ -101,7 +100,7 @@ def __init__(
 
         # Generate a silence audio to use as keep-alive message
         self._keep_alive_data = generate_silence(
-            sample_rate=sample_rate, duration_ms=10
+            sample_rate=self.sample_rate, duration_ms=10
         )
         self._keep_alive_interval = keep_alive_interval
 
@@ -121,7 +120,7 @@ async def start(self):
         """
         Start the main task establishing the Deepgram connection and processing the events.
         """
-        if self._is_closed:
+        if self.closed:
             logger.warning("Cannot setup connection - Deepgram instance is closed")
             return None
 
@@ -178,15 +177,8 @@ async def started(self):
         )
 
     async def close(self):
+        await super().close()
         """Close the Deepgram connection and clean up resources."""
-        if self._is_closed:
-            logger.debug("Deepgram STT service already closed")
-            return
-
-        logger.info("Closing Deepgram STT service")
-        self._is_closed = True
-
-        # Close the Deepgram connection if it exists
         if self.dg_connection:
             logger.debug("Closing Deepgram connection")
             try:
@@ -261,29 +253,15 @@ async def _on_connection_close(self, message: Any):
         logger.warning(f"Deepgram connection closed. message={message}")
         await self.close()
 
-    async def _process_audio_impl(
+    async def process_audio(
         self,
         pcm_data: PcmData,
-        user_metadata: Optional[Union[Dict[str, Any], "Participant"]] = None,
-    ) -> Optional[List[Tuple[bool, str, Dict[str, Any]]]]:
-        """
-        Process audio data through Deepgram for transcription.
-
-        Args:
-            pcm_data: The PCM audio data to process.
-            user_metadata: Additional metadata about the user or session.
-
-        Returns:
-            None - Deepgram operates in asynchronous mode and emits events directly
-            when transcripts arrive from the streaming service.
-        """
-        if self._is_closed:
+        participant: Optional[Participant] = None,
+    ):
+        if self.closed:
             logger.warning("Deepgram connection is closed, ignoring audio")
             return None
 
-        # Store the current user context for transcript events
-        self._current_user = user_metadata  # type: ignore[assignment]
-
         # Check if the input sample rate matches the expected sample rate
         if pcm_data.sample_rate != self.sample_rate:
             logger.warning(
@@ -334,7 +312,7 @@ async def _keepalive_loop(self):
         Send the silence audio every `interval` seconds
         to prevent Deepgram from closing the connection.
         """
-        while not self._is_closed and self.dg_connection is not None:
+        while not self.closed and self.dg_connection is not None:
             if self._last_sent_at + self._keep_alive_interval <= time.time():
                 logger.debug("Sending keepalive packet to Deepgram...")
                 # Send audio silence to keep the connection open
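With sample_rate removed from the constructor and pinned to 48 kHz internally, callers now have to deliver matching audio; the process_audio override above only warns on a mismatch. A hedged usage sketch follows. The plugin class name STT, the import style, and the PcmData constructor arguments are assumptions inferred from the file path and the diff, not confirmed by this commit:

import asyncio

import numpy as np
from getstream.video.rtc.track_util import PcmData
from vision_agents.plugins import deepgram  # import style assumed


async def main():
    # No sample_rate argument anymore; the API key falls back to DEEPGRAM_API_KEY.
    stt = deepgram.STT(interim_results=True)  # class name assumed
    await stt.start()

    # One 20 ms chunk at the fixed 48 kHz rate the plugin now expects:
    # 48000 samples/s * 0.020 s = 960 samples.
    samples = np.zeros(960, dtype=np.int16)
    pcm = PcmData(samples=samples, sample_rate=48000)  # constructor assumed
    await stt.process_audio(pcm)

    await stt.close()


asyncio.run(main())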
