Skip to content

Commit bdc6fa1

Browse files
committed
feat(azure-stt): add cancellation tracing and session guards
1 parent 5e308c3 commit bdc6fa1

File tree

4 files changed

+212
-35
lines changed

4 files changed

+212
-35
lines changed

changelog/3884.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
- Added Azure STT cancellation tracing attributes and session-termination guards: canceled recognition sessions now surface structured observability data, and the service no longer keeps accepting audio as if the session were still healthy.

src/pipecat/services/azure/stt.py

Lines changed: 107 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from pipecat.services.stt_service import STTService
3232
from pipecat.transcriptions.language import Language
3333
from pipecat.utils.time import time_now_iso8601
34-
from pipecat.utils.tracing.service_decorators import traced_stt
34+
from pipecat.utils.tracing.service_decorators import trace_stt_cancellation, traced_stt
3535

3636
try:
3737
from azure.cognitiveservices.speech import (
@@ -123,6 +123,10 @@ def __init__(
123123

124124
self._audio_stream = None
125125
self._speech_recognizer = None
126+
self._audio_sent = False
127+
self._recognition_active = False
128+
self._recognition_terminated = False
129+
self._shutdown_requested = False
126130

127131
def can_generate_metrics(self) -> bool:
128132
"""Check if this service can generate performance metrics.
@@ -179,7 +183,12 @@ async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
179183
try:
180184
await self.start_processing_metrics()
181185
if self._audio_stream:
186+
if self._recognition_terminated and not self._shutdown_requested:
187+
logger.warning("Azure STT recognition terminated, dropping audio chunk")
188+
yield None
189+
return
182190
self._audio_stream.write(audio)
191+
self._audio_sent = True
183192
yield None
184193
except Exception as e:
185194
yield ErrorFrame(error=f"Unknown error occurred: {e}")
@@ -198,6 +207,11 @@ async def start(self, frame: StartFrame):
198207
if self._audio_stream:
199208
return
200209

210+
self._audio_sent = False
211+
self._recognition_active = False
212+
self._recognition_terminated = False
213+
self._shutdown_requested = False
214+
201215
try:
202216
stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
203217
self._audio_stream = PushAudioInputStream(stream_format)
@@ -230,6 +244,10 @@ async def stop(self, frame: EndFrame):
230244
"""
231245
await super().stop(frame)
232246

247+
self._shutdown_requested = True
248+
self._recognition_active = False
249+
self._recognition_terminated = True
250+
233251
if self._speech_recognizer:
234252
self._speech_recognizer.stop_continuous_recognition_async()
235253

@@ -246,6 +264,10 @@ async def cancel(self, frame: CancelFrame):
246264
"""
247265
await super().cancel(frame)
248266

267+
self._shutdown_requested = True
268+
self._recognition_active = False
269+
self._recognition_terminated = True
270+
249271
if self._speech_recognizer:
250272
self._speech_recognizer.stop_continuous_recognition_async()
251273

@@ -259,6 +281,25 @@ async def _handle_transcription(
259281
"""Handle a transcription result with tracing."""
260282
await self.stop_processing_metrics()
261283

284+
async def _trace_cancellation(
    self,
    *,
    reason: str,
    code: str,
    recoverable: bool,
    phase: str,
):
    """Emit the tracing span for a canceled Azure STT recognition.

    Forwards the cancellation metadata to ``trace_stt_cancellation`` with the
    fixed error type ``"azure.stt.canceled"``. The method awaits nothing; it
    is a coroutine so the cancel handler can schedule it onto the service's
    event loop with ``asyncio.run_coroutine_threadsafe``.

    Args:
        reason: Normalized cancellation reason.
        code: Normalized cancellation error code.
        recoverable: Whether the cancellation is classified as recoverable.
        phase: Lifecycle phase in which the cancellation occurred.
    """
    # Only a plain string region is reported; anything else is omitted.
    span_region = self._settings.region if isinstance(self._settings.region, str) else None
    trace_stt_cancellation(
        self,
        error_type="azure.stt.canceled",
        cancel_reason=reason,
        cancel_code=code,
        recoverable=recoverable,
        phase=phase,
        region=span_region,
    )
302+
262303
def _on_handle_recognized(self, event):
263304
if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
264305
language = getattr(event.result, "language", None) or self._settings.language
@@ -288,30 +329,87 @@ def _on_handle_recognizing(self, event):
288329

289330
def _on_handle_canceled(self, event):
290331
details = getattr(event, "cancellation_details", None)
291-
reason = getattr(details, "reason", "UNKNOWN")
292-
code = getattr(details, "code", "UNKNOWN")
332+
reason = self._normalize_cancellation_value(getattr(details, "reason", "UNKNOWN"))
333+
code = self._normalize_cancellation_value(getattr(details, "code", "UNKNOWN"))
293334
error_details = getattr(details, "error_details", "")
335+
phase = self._get_cancellation_phase()
336+
recoverable = self._is_cancellation_recoverable(reason, code)
337+
338+
self._recognition_active = False
339+
self._recognition_terminated = True
294340

295341
logger.error(
296-
"Azure STT recognition canceled: reason={}, code={}, details={}",
342+
"Azure STT recognition canceled: reason={}, code={}, phase={}, recoverable={}, details={}",
297343
reason,
298344
code,
345+
phase,
346+
recoverable,
299347
error_details,
300348
)
301349

302-
error_message = f"Azure STT recognition canceled: {code} - {error_details}"
350+
asyncio.run_coroutine_threadsafe(
351+
self._trace_cancellation(
352+
reason=reason,
353+
code=code,
354+
recoverable=recoverable,
355+
phase=phase,
356+
),
357+
self.get_event_loop(),
358+
)
359+
360+
error_message = f"Azure STT recognition canceled: {reason} ({code})"
303361
asyncio.run_coroutine_threadsafe(
304362
self.push_error(error_msg=error_message), self.get_event_loop()
305363
)
306364

307365
def _on_handle_session_started(self, event):
    """Mark recognition as live and log the id of the started session.

    Args:
        event: Session event; its ``session_id`` attribute is logged, with
            ``"unknown"`` used when the attribute is missing.
    """
    # A new session supersedes any termination recorded earlier.
    self._recognition_active = True
    self._recognition_terminated = False

    session_id = getattr(event, "session_id", "unknown")
    logger.info("Azure STT session started: session_id={}", session_id)
312372

313373
def _on_handle_session_stopped(self, event):
314-
logger.warning(
315-
"Azure STT session stopped: session_id={}",
316-
getattr(event, "session_id", "unknown"),
317-
)
374+
self._recognition_active = False
375+
self._recognition_terminated = True
376+
if self._shutdown_requested:
377+
logger.info(
378+
"Azure STT session stopped during shutdown: session_id={}",
379+
getattr(event, "session_id", "unknown"),
380+
)
381+
else:
382+
logger.warning(
383+
"Azure STT session stopped: session_id={}",
384+
getattr(event, "session_id", "unknown"),
385+
)
386+
387+
@staticmethod
388+
def _normalize_cancellation_value(value: Any) -> str:
389+
normalized = getattr(value, "name", None)
390+
if normalized:
391+
return normalized
392+
return str(value)
393+
394+
def _get_cancellation_phase(self) -> str:
395+
if self._shutdown_requested:
396+
return "shutdown"
397+
if not self._recognition_active and not self._audio_sent:
398+
return "startup"
399+
return "streaming"
400+
401+
@staticmethod
402+
def _is_cancellation_recoverable(reason: str, code: str) -> bool:
403+
if reason == "CancelledByUser":
404+
return True
405+
if reason != "Error":
406+
return False
407+
408+
return code in {
409+
"ConnectionFailure",
410+
"ServiceRedirectPermanent",
411+
"ServiceRedirectTemporary",
412+
"ServiceTimeout",
413+
"ServiceUnavailable",
414+
"TooManyRequests",
415+
}

src/pipecat/utils/tracing/service_decorators.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,51 @@ async def wrapper(self, transcript, is_final, language=None):
373373
return decorator
374374

375375

376+
def trace_stt_cancellation(
    service,
    *,
    error_type: str,
    cancel_reason: str,
    cancel_code: str,
    recoverable: bool,
    phase: str,
    region: Optional[str] = None,
) -> None:
    """Record an ``stt.cancel`` span describing a canceled STT recognition.

    Returns without doing anything when tracing is unavailable or the service
    does not have a truthy ``_tracing_enabled`` attribute.

    Args:
        service: STT service instance generating the cancellation.
        error_type: Stable error classification, stored as ``error.type``.
        cancel_reason: Provider cancellation reason.
        cancel_code: Provider cancellation code.
        recoverable: Whether the application should attempt recovery.
        phase: Service lifecycle phase where cancellation happened.
        region: Cloud region associated with the service, if known; recorded
            as ``cloud.region`` only when truthy.
    """
    if not is_tracing_available():
        return
    if not getattr(service, "_tracing_enabled", False):
        return

    service_class_name = service.__class__.__name__
    # Prefer the turn context; fall back to the parent service context.
    parent_context = _get_turn_context(service) or _get_parent_service_context(service)

    # Insertion order matches the order in which attributes are set.
    span_attributes = {
        "gen_ai.system": service_class_name.replace("STTService", "").lower(),
        "gen_ai.operation.name": "stt.cancel",
        "error.type": error_type,
        "stt.cancel.reason": cancel_reason,
        "stt.cancel.code": cancel_code,
        "stt.cancel.recoverable": recoverable,
        "stt.cancel.phase": phase,
    }
    if region:
        span_attributes["cloud.region"] = region

    tracer = trace.get_tracer("pipecat")
    with tracer.start_as_current_span("stt.cancel", context=parent_context) as span:
        for attribute_name, attribute_value in span_attributes.items():
            span.set_attribute(attribute_name, attribute_value)

        # Only reason "Error" marks the span as failed.
        if cancel_reason == "Error":
            span.set_status(trace.Status(trace.StatusCode.ERROR, cancel_code))
419+
420+
376421
def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -> Callable:
377422
"""Trace LLM service methods with LLM-specific attributes.
378423

tests/test_azure_stt.py

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ class ConnectionClosedOK(Exception):
3939
class WebSocketClientProtocol:
4040
pass
4141

42-
protocol_module.State = State
43-
exceptions_module.ConnectionClosedError = ConnectionClosedError
44-
exceptions_module.ConnectionClosedOK = ConnectionClosedOK
45-
websockets_module.protocol = protocol_module
46-
websockets_module.exceptions = exceptions_module
47-
websockets_module.WebSocketClientProtocol = WebSocketClientProtocol
42+
setattr(protocol_module, "State", State)
43+
setattr(exceptions_module, "ConnectionClosedError", ConnectionClosedError)
44+
setattr(exceptions_module, "ConnectionClosedOK", ConnectionClosedOK)
45+
setattr(websockets_module, "protocol", protocol_module)
46+
setattr(websockets_module, "exceptions", exceptions_module)
47+
setattr(websockets_module, "WebSocketClientProtocol", WebSocketClientProtocol)
4848

4949

5050
def _install_azure_speech_stub() -> None:
@@ -118,17 +118,17 @@ class AudioConfig:
118118
def __init__(self, stream=None):
119119
self.stream = stream
120120

121-
speech_module.ResultReason = ResultReason
122-
speech_module.SpeechConfig = SpeechConfig
123-
speech_module.SpeechRecognizer = SpeechRecognizer
124-
audio_module.AudioStreamFormat = AudioStreamFormat
125-
audio_module.PushAudioInputStream = PushAudioInputStream
126-
dialog_module.AudioConfig = AudioConfig
121+
setattr(speech_module, "ResultReason", ResultReason)
122+
setattr(speech_module, "SpeechConfig", SpeechConfig)
123+
setattr(speech_module, "SpeechRecognizer", SpeechRecognizer)
124+
setattr(audio_module, "AudioStreamFormat", AudioStreamFormat)
125+
setattr(audio_module, "PushAudioInputStream", PushAudioInputStream)
126+
setattr(dialog_module, "AudioConfig", AudioConfig)
127127

128-
azure_module.cognitiveservices = cognitiveservices_module
129-
cognitiveservices_module.speech = speech_module
130-
speech_module.audio = audio_module
131-
speech_module.dialog = dialog_module
128+
setattr(azure_module, "cognitiveservices", cognitiveservices_module)
129+
setattr(cognitiveservices_module, "speech", speech_module)
130+
setattr(speech_module, "audio", audio_module)
131+
setattr(speech_module, "dialog", dialog_module)
132132

133133

134134
_install_azure_speech_stub()
@@ -137,16 +137,16 @@ def __init__(self, stream=None):
137137
azure_package_module = types.ModuleType("pipecat.services.azure")
138138
azure_package_module.__path__ = []
139139
common_module = types.ModuleType("pipecat.services.azure.common")
140-
common_module.language_to_azure_language = lambda _language: "en-US"
140+
setattr(common_module, "language_to_azure_language", lambda _language: "en-US")
141141

142142
sys.modules["pipecat.services.azure"] = azure_package_module
143143
sys.modules["pipecat.services.azure.common"] = common_module
144144

145145
stt_file = pathlib.Path(__file__).resolve().parents[1] / "src/pipecat/services/azure/stt.py"
146146
spec = importlib.util.spec_from_file_location("pipecat.services.azure.stt", stt_file)
147+
assert spec is not None and spec.loader is not None
147148
stt_module = importlib.util.module_from_spec(spec)
148149
sys.modules["pipecat.services.azure.stt"] = stt_module
149-
assert spec and spec.loader
150150
spec.loader.exec_module(stt_module)
151151

152152
from pipecat.frames.frames import StartFrame
@@ -192,6 +192,7 @@ def test_canceled_handler_pushes_error_with_details():
192192
service = stt_module.AzureSTTService(api_key="test-key", region="eastus")
193193
service.push_error = AsyncMock()
194194
service.get_event_loop = MagicMock(return_value=MagicMock())
195+
service._trace_cancellation = AsyncMock()
195196

196197
canceled_event = MagicMock()
197198
canceled_event.cancellation_details.reason = "Error"
@@ -202,12 +203,19 @@ def test_canceled_handler_pushes_error_with_details():
202203
service._on_handle_canceled(canceled_event)
203204

204205
service.push_error.assert_called_once_with(
205-
error_msg="Azure STT recognition canceled: AuthenticationFailure - 401 Unauthorized"
206+
error_msg="Azure STT recognition canceled: Error (AuthenticationFailure)"
206207
)
207-
assert run_threadsafe.call_count == 1
208+
service._trace_cancellation.assert_called_once_with(
209+
reason="Error",
210+
code="AuthenticationFailure",
211+
recoverable=False,
212+
phase="startup",
213+
)
214+
assert run_threadsafe.call_count == 2
208215

209-
pending_coroutine = run_threadsafe.call_args.args[0]
210-
pending_coroutine.close()
216+
for call in run_threadsafe.call_args_list:
217+
pending_coroutine = call.args[0]
218+
pending_coroutine.close()
211219

212220

213221
def test_canceled_handler_uses_safe_defaults_when_details_missing():
@@ -216,6 +224,7 @@ def test_canceled_handler_uses_safe_defaults_when_details_missing():
216224
service = stt_module.AzureSTTService(api_key="test-key", region="eastus")
217225
service.push_error = AsyncMock()
218226
service.get_event_loop = MagicMock(return_value=MagicMock())
227+
service._trace_cancellation = AsyncMock()
219228

220229
canceled_event = MagicMock()
221230
canceled_event.cancellation_details = None
@@ -224,9 +233,33 @@ def test_canceled_handler_uses_safe_defaults_when_details_missing():
224233
service._on_handle_canceled(canceled_event)
225234

226235
service.push_error.assert_called_once_with(
227-
error_msg="Azure STT recognition canceled: UNKNOWN - "
236+
error_msg="Azure STT recognition canceled: UNKNOWN (UNKNOWN)"
237+
)
238+
service._trace_cancellation.assert_called_once_with(
239+
reason="UNKNOWN",
240+
code="UNKNOWN",
241+
recoverable=False,
242+
phase="startup",
228243
)
229-
assert run_threadsafe.call_count == 1
244+
assert run_threadsafe.call_count == 2
245+
246+
for call in run_threadsafe.call_args_list:
247+
pending_coroutine = call.args[0]
248+
pending_coroutine.close()
249+
250+
251+
@pytest.mark.asyncio
252+
async def test_run_stt_drops_audio_after_terminated_session():
253+
"""Verify audio is not written after a terminated session."""
254+
255+
service = stt_module.AzureSTTService(api_key="test-key", region="eastus")
256+
service.start_processing_metrics = AsyncMock()
257+
service._audio_stream = MagicMock()
258+
service._recognition_terminated = True
259+
260+
frames = []
261+
async for frame in service.run_stt(b"audio"):
262+
frames.append(frame)
230263

231-
pending_coroutine = run_threadsafe.call_args.args[0]
232-
pending_coroutine.close()
264+
service._audio_stream.write.assert_not_called()
265+
assert frames == [None]

0 commit comments

Comments
 (0)