Skip to content

Commit 4f003cc

Browse files
feat: support passing metadata (#53) (#56)
* feat: support passing metadata * pass credentials via metadata call credentials * Update common proto submodule Co-authored-by: Viraj Karandikar <[email protected]>
1 parent 185e3ff commit 4f003cc

15 files changed

+31
-17
lines changed

riva/client/argparse_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,5 @@ def add_connection_argparse_parameters(parser: argparse.ArgumentParser) -> argpa
5757
parser.add_argument(
5858
"--use-ssl", action='store_true', help="Boolean to control if SSL/TLS encryption should be used."
5959
)
60+
parser.add_argument("--metadata", action='append', nargs='+', help="Send HTTP Header(s) to server")
6061
return parser

riva/client/auth.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,26 @@
44
import os
55
from pathlib import Path
66
from typing import List, Optional, Tuple, Union
7-
87
import grpc
98

109

1110
def create_channel(
12-
ssl_cert: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, uri: str = "localhost:50051",
11+
ssl_cert: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, uri: str = "localhost:50051", metadata: Optional[List[Tuple[str, str]]] = None,
1312
) -> grpc.Channel:
13+
14+
def metadata_callback(context, callback):
15+
callback(metadata, None)
16+
1417
if ssl_cert is not None or use_ssl:
1518
root_certificates = None
1619
if ssl_cert is not None:
1720
ssl_cert = Path(ssl_cert).expanduser()
1821
with open(ssl_cert, 'rb') as f:
1922
root_certificates = f.read()
2023
creds = grpc.ssl_channel_credentials(root_certificates)
24+
if metadata:
25+
auth_creds = grpc.metadata_call_credentials(metadata_callback)
26+
creds = grpc.composite_channel_credentials(creds, auth_creds)
2127
channel = grpc.secure_channel(uri, creds)
2228
else:
2329
channel = grpc.insecure_channel(uri)
@@ -30,6 +36,7 @@ def __init__(
3036
ssl_cert: Optional[Union[str, os.PathLike]] = None,
3137
use_ssl: bool = False,
3238
uri: str = "localhost:50051",
39+
metadata_args: List[List[str]] = None,
3340
) -> None:
3441
"""
3542
A class responsible for establishing connection with a server and providing security metadata.
@@ -44,7 +51,13 @@ def __init__(
4451
self.ssl_cert: Optional[Path] = None if ssl_cert is None else Path(ssl_cert).expanduser()
4552
self.uri: str = uri
4653
self.use_ssl: bool = use_ssl
47-
self.channel: grpc.Channel = create_channel(self.ssl_cert, self.use_ssl, self.uri)
54+
self.metadata = []
55+
if metadata_args:
56+
for meta in metadata_args:
57+
if len(meta) != 2:
58+
raise ValueError(f"Metadata should have 2 parameters in \"key\" \"value\" pair. Receieved {len(meta)} parameters.")
59+
self.metadata.append(tuple(meta))
60+
self.channel: grpc.Channel = create_channel(self.ssl_cert, self.use_ssl, self.uri, self.metadata)
4861

4962
def get_auth_metadata(self) -> List[Tuple[str, str]]:
5063
"""

scripts/asr/riva_streaming_asr_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def streaming_transcription_worker(
5050
) -> None:
5151
output_file = Path(output_file).expanduser()
5252
try:
53-
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
53+
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
5454
asr_service = riva.client.ASRService(auth)
5555
config = riva.client.StreamingRecognitionConfig(
5656
config=riva.client.RecognitionConfig(

scripts/asr/transcribe_file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def main() -> None:
6666
if args.list_devices:
6767
riva.client.audio_io.list_output_devices()
6868
return
69-
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
69+
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
7070
asr_service = riva.client.ASRService(auth)
7171
config = riva.client.StreamingRecognitionConfig(
7272
config=riva.client.RecognitionConfig(

scripts/asr/transcribe_file_offline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def parse_args() -> argparse.Namespace:
2626

2727
def main() -> None:
2828
args = parse_args()
29-
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
29+
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
3030
asr_service = riva.client.ASRService(auth)
3131
config = riva.client.RecognitionConfig(
3232
language_code=args.language_code,

scripts/asr/transcribe_mic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def main() -> None:
4141
if args.list_devices:
4242
riva.client.audio_io.list_input_devices()
4343
return
44-
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
44+
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
4545
asr_service = riva.client.ASRService(auth)
4646
config = riva.client.StreamingRecognitionConfig(
4747
config=riva.client.RecognitionConfig(

scripts/nlp/eval_intent_slot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def parse_args() -> argparse.Namespace:
340340

341341
def main() -> None:
342342
args = parse_args()
343-
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
343+
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
344344
service = riva.client.NLPService(auth)
345345
intent_report, slot_report = intent_slots_classification_report(
346346
args.input_file,

scripts/nlp/intentslot_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def pretty_print_result(
4444

4545
def main() -> None:
4646
args = parse_args()
47-
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
47+
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
4848
service = riva.client.NLPService(auth)
4949
if args.interactive:
5050
while True:

scripts/nlp/ner_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def parse_args() -> argparse.Namespace:
3131

3232
def main() -> None:
3333
args = parse_args()
34-
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
34+
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
3535
service = riva.client.NLPService(auth)
3636
tokens, slots, slot_confidences, starts, ends = riva.client.extract_most_probable_token_classification_predictions(
3737
service.classify_tokens(input_strings=args.query, model_name=args.model)

0 commit comments

Comments
 (0)