Skip to content

Commit 538f373

Browse files
MadhavShroff and claude committed
Add async gRPC ASR client with grpc.aio support
Adds AsyncAuth and ASRServiceAsync classes for native async/await speech recognition, enabling efficient high-concurrency scenarios without thread overhead. Features: - AsyncAuth: Async channel management with SSL/mTLS support - ASRServiceAsync: Streaming and batch recognition methods - Double-checked locking for lock-free fast paths - Async file I/O for SSL certificate loading - Comprehensive test coverage (26 tests) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 706053a commit 538f373

File tree

3 files changed

+993
-0
lines changed

3 files changed

+993
-0
lines changed

riva/client/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,6 @@
4141
from riva.client.proto.riva_nmt_pb2 import StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig, StreamingTranslateSpeechToTextConfig
4242
from riva.client.tts import SpeechSynthesisService
4343
from riva.client.nmt import NeuralMachineTranslationClient
44+
45+
# Async extensions (grpc.aio)
46+
from riva.client.asr_async import ASRServiceAsync, AsyncAuth

riva/client/asr_async.py

Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
4+
"""Async ASR client using grpc.aio.
5+
6+
This module provides async/await support for Riva ASR streaming,
7+
enabling efficient high-concurrency scenarios without thread overhead.
8+
9+
Example:
10+
async with AsyncAuth(uri="localhost:50051") as auth:
11+
service = ASRServiceAsync(auth)
12+
async for response in service.streaming_recognize(audio_gen, config):
13+
print(response.results)
14+
"""
15+
16+
from __future__ import annotations
17+
18+
import asyncio
19+
from typing import AsyncIterator, Sequence
20+
21+
import grpc
22+
import grpc.aio
23+
24+
from riva.client.proto import riva_asr_pb2 as rasr
25+
from riva.client.proto import riva_asr_pb2_grpc as rasr_srv
26+
27+
__all__ = ["AsyncAuth", "ASRServiceAsync"]
28+
29+
30+
class AsyncAuth:
    """Connection settings plus a lazily created ``grpc.aio`` channel.

    The channel is not opened at construction time; it is built on the first
    ``get_channel()`` call and then reused. Channel creation and teardown are
    serialized between coroutines with an ``asyncio.Lock``. Both insecure and
    SSL/TLS (optionally mutual TLS) channels are supported.

    Args:
        uri: Riva server address (host:port)
        use_ssl: Create a secure channel instead of an insecure one
        ssl_root_cert: Path to root CA certificate (optional)
        ssl_client_cert: Path to client certificate for mTLS (optional)
        ssl_client_key: Path to client private key for mTLS (optional)
        metadata: (key, value) tuples exposed through ``get_auth_metadata()``
        options: gRPC channel options; ``DEFAULT_OPTIONS`` is used when this
            is omitted or empty

    Example:
        # Simple insecure connection
        auth = AsyncAuth(uri="localhost:50051")

        # SSL
        auth = AsyncAuth(uri="riva.example.com:443", use_ssl=True)

        # With API key metadata
        auth = AsyncAuth(
            uri="riva.example.com:443",
            use_ssl=True,
            metadata=[("x-api-key", "your-key")]
        )

        # As context manager (recommended): the channel is closed on exit
        async with AsyncAuth(uri="localhost:50051") as auth:
            service = ASRServiceAsync(auth)
            # use service...
    """

    # Channel options applied when the caller does not supply any.
    DEFAULT_OPTIONS: Sequence[tuple[str, int | bool]] = (
        ("grpc.max_send_message_length", 50 * 1024 * 1024),  # 50MB
        ("grpc.max_receive_message_length", 50 * 1024 * 1024),  # 50MB
        ("grpc.keepalive_time_ms", 10_000),  # 10 sec
        ("grpc.keepalive_timeout_ms", 5_000),  # 5 sec
        ("grpc.keepalive_permit_without_calls", True),
        ("grpc.http2.min_ping_interval_without_data_ms", 5_000),
    )

    def __init__(
        self,
        uri: str,
        use_ssl: bool = False,
        ssl_root_cert: str | None = None,
        ssl_client_cert: str | None = None,
        ssl_client_key: str | None = None,
        metadata: Sequence[tuple[str, str]] | None = None,
        options: Sequence[tuple[str, int | bool | str]] | None = None,
    ) -> None:
        self.uri = uri
        self.use_ssl = use_ssl
        self.ssl_root_cert = ssl_root_cert
        self.ssl_client_cert = ssl_client_cert
        self.ssl_client_key = ssl_client_key
        # Copy into fresh lists so the instance never aliases caller-owned
        # sequences; a missing or empty argument falls back to the default.
        self.metadata = list(metadata or ())
        self._options = list(options or self.DEFAULT_OPTIONS)

        # Created on demand by get_channel(); reset to None by close().
        self._channel: grpc.aio.Channel | None = None
        self._lock = asyncio.Lock()

    async def get_channel(self) -> grpc.aio.Channel:
        """Return the shared channel, creating it on first use.

        When a channel already exists it is returned without touching the
        lock. Otherwise the lock is taken and the check is repeated, so
        coroutines racing on the first call end up sharing one channel.

        Returns:
            The async gRPC channel
        """
        channel = self._channel
        if channel is None:
            async with self._lock:
                # Re-check: another coroutine may have created the channel
                # while this one was waiting for the lock.
                if self._channel is None:
                    self._channel = await self._create_channel()
                channel = self._channel
        return channel

    async def _create_channel(self) -> grpc.aio.Channel:
        """Build an insecure or secure channel depending on ``use_ssl``."""
        if not self.use_ssl:
            return grpc.aio.insecure_channel(
                self.uri,
                options=self._options,
            )

        ssl_credentials = await self._create_ssl_credentials()
        return grpc.aio.secure_channel(
            self.uri,
            ssl_credentials,
            options=self._options,
        )

    async def _create_ssl_credentials(self) -> grpc.ChannelCredentials:
        """Build SSL channel credentials from the configured PEM files.

        Each configured file is read through ``asyncio.to_thread()`` so the
        blocking read happens off the event loop. Paths that are unset are
        passed to gRPC as ``None``.
        """

        def _slurp(path: str) -> bytes:
            with open(path, "rb") as handle:
                return handle.read()

        async def _load_optional(path: str | None) -> bytes | None:
            # Unset (None or empty) paths are skipped rather than opened.
            if not path:
                return None
            return await asyncio.to_thread(_slurp, path)

        # Files are read one after another: root CA, client cert, client key.
        root_pem = await _load_optional(self.ssl_root_cert)
        chain_pem = await _load_optional(self.ssl_client_cert)
        key_pem = await _load_optional(self.ssl_client_key)

        return grpc.ssl_channel_credentials(
            root_certificates=root_pem,
            private_key=key_pem,
            certificate_chain=chain_pem,
        )

    def get_auth_metadata(self) -> list[tuple[str, str]]:
        """Return the metadata to attach to RPC calls.

        The instance's own list is returned (not a copy).

        Returns:
            List of (key, value) metadata tuples
        """
        return self.metadata

    async def close(self) -> None:
        """Close the channel, if one was created, and forget it."""
        async with self._lock:
            channel = self._channel
            if channel is None:
                return
            await channel.close()
            self._channel = None

    async def __aenter__(self) -> "AsyncAuth":
        """Enter the async context; returns this instance unchanged."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the async context, closing the channel."""
        await self.close()
182+
183+
class ASRServiceAsync:
    """Async ASR service using grpc.aio.

    Provides async streaming and batch recognition methods that can handle
    many concurrent streams without thread overhead. Intended to be used from
    coroutines on a single event loop.

    Args:
        auth: AsyncAuth instance for channel management

    Example:
        auth = AsyncAuth(uri="localhost:50051")
        service = ASRServiceAsync(auth)

        # Streaming recognition
        async def audio_generator():
            while audio_available:
                yield audio_chunk

        async for response in service.streaming_recognize(
            audio_generator(),
            streaming_config
        ):
            for result in response.results:
                print(result.alternatives[0].transcript)

        await auth.close()
    """

    def __init__(self, auth: "AsyncAuth") -> None:
        self.auth = auth
        self._stub: "rasr_srv.RivaSpeechRecognitionStub | None" = None
        # Channel the cached stub was built on. Stays None until this class
        # creates a stub itself, so a stub assigned directly to ``_stub`` is
        # recognisable as externally provided.
        self._stub_channel: "grpc.aio.Channel | None" = None
        # Stub construction is synchronous and needs no lock; the attribute
        # is kept so code that references it keeps working.
        self._stub_lock = asyncio.Lock()
        # Resolved once at construction; an empty metadata list is sent as
        # None (i.e. no metadata argument content).
        self._metadata = auth.get_auth_metadata() or None

    async def _get_stub(self) -> "rasr_srv.RivaSpeechRecognitionStub":
        """Return a stub bound to the auth object's current channel.

        The stub is cached together with the channel it was created on. If
        ``auth`` later hands out a different channel (for example after
        ``auth.close()`` followed by a new ``get_channel()``), the stub is
        rebuilt instead of continuing to issue RPCs on the closed channel.

        A stub that was assigned to ``_stub`` directly, without a recorded
        channel, is returned unchanged because there is nothing to validate
        it against.
        """
        injected = self._stub
        if injected is not None and self._stub_channel is None:
            return injected

        channel = await self.auth.get_channel()
        # No await between this check and the assignments below, so other
        # coroutines on the same event loop cannot interleave here.
        if self._stub is None or self._stub_channel is not channel:
            self._stub = rasr_srv.RivaSpeechRecognitionStub(channel)
            self._stub_channel = channel
        return self._stub

    async def streaming_recognize(
        self,
        audio_chunks: AsyncIterator[bytes],
        streaming_config: "rasr.StreamingRecognitionConfig",
    ) -> AsyncIterator["rasr.StreamingRecognizeResponse"]:
        """Perform async streaming speech recognition.

        This is the primary method for real-time speech recognition.
        Audio is streamed to the server and partial/final results are
        yielded as they become available. The first request carries only the
        configuration; every following request carries one audio chunk.

        If the consumer stops iterating early, or an exception propagates
        out of the loop, the underlying RPC is cancelled so the stream does
        not stay open until garbage collection.

        Args:
            audio_chunks: Async iterator yielding raw audio bytes
                (LINEAR_PCM format recommended, 16-bit, mono)
            streaming_config: Configuration including sample rate,
                language, and interim_results setting

        Yields:
            StreamingRecognizeResponse objects containing transcription
            results. Check result.is_final to distinguish partial from
            final results.

        Raises:
            grpc.aio.AioRpcError: On gRPC communication errors

        Example:
            config = StreamingRecognitionConfig(
                config=RecognitionConfig(
                    encoding=AudioEncoding.LINEAR_PCM,
                    sample_rate_hertz=16000,
                    language_code="en-US",
                ),
                interim_results=True,
            )

            async for response in service.streaming_recognize(
                audio_generator(), config
            ):
                for result in response.results:
                    transcript = result.alternatives[0].transcript
                    if result.is_final:
                        print(f"Final: {transcript}")
                    else:
                        print(f"Partial: {transcript}")
        """
        stub = await self._get_stub()
        metadata = self._metadata

        async def request_generator() -> "AsyncIterator[rasr.StreamingRecognizeRequest]":
            # First request: config only (no audio)
            yield rasr.StreamingRecognizeRequest(streaming_config=streaming_config)
            # Subsequent requests: audio only
            async for chunk in audio_chunks:
                yield rasr.StreamingRecognizeRequest(audio_content=chunk)

        call = stub.StreamingRecognize(
            request_generator(),
            metadata=metadata,
        )

        try:
            async for response in call:
                yield response
        finally:
            # Runs on normal completion, on errors, and when the consumer
            # closes this generator early. Cancelling an RPC that already
            # finished has no effect. getattr() tolerates call objects that
            # are plain async iterators without a cancel() method.
            cancel = getattr(call, "cancel", None)
            if callable(cancel):
                cancel()

    async def recognize(
        self,
        audio_bytes: bytes,
        config: "rasr.RecognitionConfig",
    ) -> "rasr.RecognizeResponse":
        """Perform async batch (offline) speech recognition.

        Use this for complete audio files rather than streaming.

        Args:
            audio_bytes: Complete audio data
            config: Recognition configuration

        Returns:
            RecognizeResponse with transcription results

        Raises:
            grpc.aio.AioRpcError: On gRPC communication errors
        """
        stub = await self._get_stub()
        request = rasr.RecognizeRequest(config=config, audio=audio_bytes)
        return await stub.Recognize(request, metadata=self._metadata)

    async def get_config(self) -> "rasr.RivaSpeechRecognitionConfigResponse":
        """Get the server's speech recognition configuration.

        Returns:
            Configuration response with available models and settings

        Raises:
            grpc.aio.AioRpcError: On gRPC communication errors
        """
        stub = await self._get_stub()
        request = rasr.RivaSpeechRecognitionConfigRequest()
        return await stub.GetRivaSpeechRecognitionConfig(request, metadata=self._metadata)

0 commit comments

Comments
 (0)