diff --git a/riva-ts-client/.eslintrc.js b/riva-ts-client/.eslintrc.js
new file mode 100644
index 00000000..3f99706c
--- /dev/null
+++ b/riva-ts-client/.eslintrc.js
@@ -0,0 +1,15 @@
+module.exports = {
+  parser: '@typescript-eslint/parser',
+  extends: [
+    'plugin:@typescript-eslint/recommended'
+  ],
+  parserOptions: {
+    ecmaVersion: 2020,
+    sourceType: 'module'
+  },
+  rules: {
+    '@typescript-eslint/explicit-function-return-type': 'warn',
+    '@typescript-eslint/no-explicit-any': 'warn',
+    '@typescript-eslint/no-unused-vars': ['error', { 'argsIgnorePattern': '^_' }]
+  }
+};
diff --git a/riva-ts-client/README.md b/riva-ts-client/README.md
new file mode 100644
index 00000000..b6ea4576
--- /dev/null
+++ b/riva-ts-client/README.md
@@ -0,0 +1,197 @@
+# NVIDIA Riva TypeScript Client
+
+TypeScript implementation of the NVIDIA Riva client, providing a modern, type-safe interface for interacting with NVIDIA Riva services. This client is designed to be fully compatible with the Python implementation while leveraging TypeScript's type system for an enhanced developer experience.
+
+## Features
+
+### Automatic Speech Recognition (ASR)
+- Real-time streaming transcription with configurable chunk sizes
+- Offline transcription of full audio files
+- Word boosting and custom vocabulary
+- Speaker diarization with configurable speaker count
+- Custom endpointing configuration
+- Model selection and listing
+- Multi-language support
+- WAV file handling and audio format utilities
+
+### Text-to-Speech (TTS)
+- High-quality speech synthesis
+- Streaming and offline synthesis modes
+- Custom dictionary support
+- Multi-voice and multi-language support
+- SSML support
+- Audio format conversion utilities
+- WAV file output handling
+
+### Natural Language Processing (NLP)
+- Text classification with confidence scores
+- Token classification with position information
+- Entity analysis with type and score
+- Intent recognition with slot filling
+- Text transformation
+- Natural language query processing
+- Language code support
+
+### Neural Machine Translation (NMT)
+- Text-to-text translation
+- Language pair configuration
+- Batch translation support
+
+## Prerequisites
+
+- Node.js (v18.x or later)
+- npm (v8.x or later)
+- Protocol Buffers compiler (protoc)
+- TypeScript (v5.x or later)
+
+## Installation
+
+```bash
+npm install nvidia-riva-client
+```
+
+## Building from Source
+
+```bash
+git clone https://github.com/nvidia-riva/python-clients
+cd python-clients/riva-ts-client
+npm install
+npm run build
+```
+
+## Quick Start
+
+### ASR Example
+```typescript
+import { ASRService, AudioEncoding } from 'nvidia-riva-client';
+
+const asr = new ASRService({
+  serverUrl: 'localhost:50051'
+});
+
+// Streaming recognition
+async function streamingExample() {
+  const config = {
+    encoding: AudioEncoding.LINEAR_PCM,
+    sampleRateHz: 16000,
+    languageCode: 'en-US',
+    audioChannelCount: 1
+  };
+
+  for await (const response of asr.streamingRecognize(audioSource, { config })) {
+    console.log(response.results[0]?.alternatives[0]?.transcript);
+  }
+}
+
+// Offline recognition
+async function offlineExample() {
+  const config = {
+    encoding: AudioEncoding.LINEAR_PCM,
+    sampleRateHz: 16000,
+    languageCode: 'en-US',
+    audioChannelCount: 1,
+    enableSpeakerDiarization: true,
+    maxSpeakers: 2
+  };
+
+  const response = await asr.recognize(audioBuffer, config);
+  console.log(response.results[0]?.alternatives[0]?.transcript);
+}
+```
+
+### TTS Example
+```typescript
+import { SpeechSynthesisService } from 'nvidia-riva-client';
+
+const tts = new SpeechSynthesisService({
+  serverUrl: 'localhost:50051'
+});
+
+async function synthesizeExample() {
+  const response = await tts.synthesize('Hello, welcome to Riva!', {
+    language: 'en-US',
+    voice: 'English-US-Female-1',
+    sampleRateHz: 44100,
+    customDictionary: {
+      'Riva': 'R IY V AH'
+    }
+  });
+
+  // Save to WAV file
+  await response.writeToFile('output.wav');
+}
+```
+
+### NLP Example
+```typescript
+import { NLPService } from 'nvidia-riva-client';
+
+const nlp = new NLPService({
+  serverUrl: 'localhost:50051'
+});
+
+async function nlpExample() {
+  // Text Classification
+  const classifyResult = await nlp.classifyText(
+    'Great product, highly recommend!',
+    'sentiment',
+    'en-US'
+  );
+  console.log(classifyResult.results[0]?.label);
+
+  // Entity Analysis
+  const entityResult = await nlp.analyzeEntities(
+    'NVIDIA is headquartered in Santa Clara, California.'
+  );
+  console.log(entityResult.entities);
+
+  // Intent Recognition
+  const intentResult = await nlp.analyzeIntent(
+    'What is the weather like today?'
+  );
+  console.log(intentResult.intent, intentResult.confidence);
+}
+```
+
+### NMT Example
+```typescript
+import { NeuralMachineTranslationService } from 'nvidia-riva-client';
+
+const nmt = new NeuralMachineTranslationService({
+  serverUrl: 'localhost:50051'
+});
+
+async function translateExample() {
+  const result = await nmt.translate({
+    text: 'Hello, how are you?',
+    sourceLanguage: 'en-US',
+    targetLanguage: 'es-ES'
+  });
+  console.log(result.translations[0]?.text);
+}
+```
+
+## API Documentation
+
+For detailed API documentation, please refer to the [API Reference](docs/api.md).
+
+## Testing
+
+```bash
+# Run all tests
+npm test
+
+# Run tests with coverage
+npm run test:coverage
+
+# Run tests in watch mode
+npm run test:watch
+```
+
+## Contributing
+
+We welcome contributions! Please see our [Contributing Guide](CONTRIBUTING.md) for details.
+
+## License
+
+This project is licensed under the terms of the [MIT License](LICENSE).
diff --git a/riva-ts-client/package.json b/riva-ts-client/package.json
new file mode 100644
index 00000000..1ca5e800
--- /dev/null
+++ b/riva-ts-client/package.json
@@ -0,0 +1,83 @@
+{
+  "name": "nvidia-riva-client",
+  "version": "2.18.0-rc0",
+  "description": "TypeScript implementation of the Riva Client API",
+  "main": "dist/index.js",
+  "types": "dist/index.d.ts",
+  "scripts": {
+    "build": "tsc",
+    "test": "vitest run",
+    "test:watch": "vitest",
+    "test:coverage": "vitest run --coverage",
+    "proto:generate": "ts-node scripts/generate-protos.ts",
+    "lint": "eslint . --ext .ts",
+    "format": "prettier --write \"src/**/*.ts\"",
+    "clean": "rimraf dist",
+    "prebuild": "npm run clean",
+    "prepare": "npm run build",
+    "tts:talk": "ts-node scripts/tts/talk.ts"
+  },
+  "dependencies": {
+    "@grpc/grpc-js": "^1.8.0",
+    "@grpc/proto-loader": "^0.7.10",
+    "commander": "^9.4.1",
+    "google-protobuf": "^3.21.2",
+    "mic": "^2.1.2",
+    "node-wav": "^0.0.2",
+    "node-wav-player": "^0.2.0",
+    "pino": "^8.17.2",
+    "rxjs": "^7.8.1",
+    "wavefile": "^11.0.0",
+    "winston": "^3.11.0"
+  },
+  "devDependencies": {
+    "@eslint/eslintrc": "^3.0.0",
+    "@types/google-protobuf": "^3.15.12",
+    "@types/node": "^20.11.5",
+    "@types/node-wav": "^0.0.2",
+    "@typescript-eslint/eslint-plugin": "^6.19.0",
+    "@typescript-eslint/parser": "^6.19.0",
+    "@vitest/coverage-v8": "^1.6.0",
+    "eslint": "^8.56.0",
+    "prettier": "^3.2.4",
+    "protoc-gen-ts": "^0.8.7",
+    "rimraf": "^5.0.5",
+    "ts-node": "^10.9.2",
+    "ts-proto": "^1.181.2",
+    "typescript": "^5.3.3",
+    "vitest": "^1.6.0"
+  },
+  "engines": {
+    "node": ">=18.0.0"
+  },
+  "keywords": [
+    "deep learning",
+    "machine learning",
+    "gpu",
+    "NLP",
+    "ASR",
+    "TTS",
+    "NMT",
+    "nvidia",
+    "speech",
+    "language",
+    "Riva",
+    "client"
+  ],
+  "author": {
+    "name": "Anton Peganov",
+    "email": "apeganov@nvidia.com"
+  },
+  "repository": {
+    "type": "git",
+    "url": "https://github.com/nvidia-riva/python-clients"
+  },
+  "homepage": "https://github.com/nvidia-riva/python-clients",
+  "bugs": {
+    "url": "https://github.com/nvidia-riva/python-clients/issues"
+  },
+  "license": "MIT"
+}
diff --git a/riva-ts-client/proto/riva_asr.proto b/riva-ts-client/proto/riva_asr.proto
new file mode 100644
index 00000000..65d7450d
--- /dev/null
+++ b/riva-ts-client/proto/riva_asr.proto
@@ -0,0 +1,49 @@
+syntax = "proto3";
+
+package nvidia.riva;
+
+import "riva_services.proto";
+
+service RivaSpeechRecognition {
+  rpc Recognize(RecognizeRequest) returns (RecognizeResponse);
+  rpc StreamingRecognize(stream StreamingRecognizeRequest) returns (stream StreamingRecognizeResponse);
+  rpc ListModels(ListModelsRequest) returns (ListModelsResponse);
+}
+
+message RecognizeRequest {
+  AudioConfig config = 1;
+  bytes audio = 2;
+  string model = 3;
+}
+
+message RecognizeResponse {
+  message Result {
+    string transcript = 1;
+    float confidence = 2;
+    repeated WordInfo words = 3;
+  }
+  repeated Result results = 1;
+}
+
+message StreamingRecognitionConfig {
+  AudioConfig config = 1;
+  bool interim_results = 2;
+}
+
+message StreamingRecognizeRequest {
+  oneof streaming_request {
+    StreamingRecognitionConfig streaming_config = 1;
+    bytes audio_content = 2;
+  }
+}
+
+message StreamingRecognizeResponse {
+  message Result {
+    string transcript = 1;
+    float confidence = 2;
+    bool is_final = 3;
+    repeated WordInfo words = 4;
+  }
+  repeated Result results = 1;
+}
+
+message WordInfo {
+  string word = 1;
+  float start_time = 2;
+  float end_time = 3;
+  float confidence = 4;
+}
+
+message ListModelsRequest {}
+
+message ListModelsResponse {
+  message Model {
+    string name = 1;
+    repeated string languages = 2;
+    int32 sample_rate = 3;
+    bool streaming_supported = 4;
+  }
+  repeated Model models = 1;
+}
diff --git a/riva-ts-client/proto/riva_nlp.proto b/riva-ts-client/proto/riva_nlp.proto
new file mode 100644
index 00000000..0bec67c2
--- /dev/null
+++ b/riva-ts-client/proto/riva_nlp.proto
@@ -0,0 +1,102 @@
+syntax = "proto3";
+
+package nvidia.riva;
+
+service RivaLanguageUnderstanding {
+  rpc ClassifyText(ClassifyRequest) returns (ClassifyResponse);
+  rpc ClassifyTokens(TokenClassifyRequest) returns (TokenClassifyResponse);
+  rpc AnalyzeEntities(AnalyzeEntitiesRequest) returns (AnalyzeEntitiesResponse);
+  rpc AnalyzeIntent(AnalyzeIntentRequest) returns (AnalyzeIntentResponse);
+  rpc TransformText(TransformTextRequest) returns (TransformTextResponse);
+  rpc PunctuateText(TransformTextRequest) returns (TransformTextResponse);
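+  // NaturalQuery answers a free-form natural-language question against a
+  // caller-supplied context passage (see NaturalQueryRequest below).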
+  rpc NaturalQuery(NaturalQueryRequest) returns (NaturalQueryResponse);
+}
+
+message ClassifyRequest {
+  repeated string text = 1;
+  message Model {
+    string model_name = 1;
+    string language_code = 2;
+  }
+  Model model = 2;
+}
+
+message ClassifyResponse {
+  message Result {
+    string label = 1;
+    float score = 2;
+  }
+  repeated Result results = 1;
+}
+
+message TokenClassifyRequest {
+  repeated string text = 1;
+  message Model {
+    string model_name = 1;
+    string language_code = 2;
+  }
+  Model model = 2;
+}
+
+message TokenClassifyResponse {
+  message Token {
+    string text = 1;
+    string label = 2;
+    float score = 3;
+    int32 start = 4;
+    int32 end = 5;
+  }
+  message Result {
+    repeated Token tokens = 1;
+  }
+  repeated Result results = 1;
+}
+
+message AnalyzeEntitiesRequest {
+  string text = 1;
+}
+
+message AnalyzeEntitiesResponse {
+  message Entity {
+    string text = 1;
+    string type = 2;
+    float score = 3;
+    int32 start = 4;
+    int32 end = 5;
+  }
+  repeated Entity entities = 1;
+}
+
+message AnalyzeIntentRequest {
+  string text = 1;
+}
+
+message AnalyzeIntentResponse {
+  string intent = 1;
+  float confidence = 2;
+  message Slot {
+    string text = 1;
+    string type = 2;
+    float score = 3;
+  }
+  repeated Slot slots = 3;
+}
+
+message TransformTextRequest {
+  string text = 1;
+  string model = 2;
+}
+
+message TransformTextResponse {
+  string text = 1;
+}
+
+message NaturalQueryRequest {
+  string query = 1;
+  string context = 2;
+}
+
+message NaturalQueryResponse {
+  string response = 1;
+  float confidence = 2;
+}
diff --git a/riva-ts-client/proto/riva_nmt.proto b/riva-ts-client/proto/riva_nmt.proto
new file mode 100644
index 00000000..262277c1
--- /dev/null
+++ b/riva-ts-client/proto/riva_nmt.proto
@@ -0,0 +1,69 @@
+syntax = "proto3";
+
+package nvidia.riva;
+
+import "riva_services.proto";
+
+service RivaNMTService {
+  rpc TranslateText(TranslateTextRequest) returns (TranslateTextResponse);
+  rpc StreamingTranslateSpeechToSpeech(stream StreamingS2SRequest) returns (stream StreamingS2SResponse);
+  rpc StreamingTranslateSpeechToText(stream StreamingS2TRequest) returns (stream StreamingS2TResponse);
+  rpc ListSupportedLanguagePairs(AvailableLanguageRequest) returns (AvailableLanguageResponse);
+}
+
+message TranslateTextRequest {
+  string text = 1;
+  string source_language = 2;
+  string target_language = 3;
+  repeated string do_not_translate_phrases = 4;
+}
+
+message TranslateTextResponse {
+  message Translation {
+    string text = 1;
+    float score = 2;
+  }
+  repeated Translation translations = 1;
+  string text = 2;
+}
+
+message StreamingS2SRequest {
+  oneof streaming_request {
+    AudioConfig config = 1;
+    bytes audio_content = 2;
+  }
+}
+
+message StreamingS2SResponse {
+  message Result {
+    string transcript = 1;
+    string translation = 2;
+    bool is_partial = 3;
+    bytes audio_content = 4;
+  }
+  Result result = 1;
+}
+
+message StreamingS2TRequest {
+  oneof streaming_request {
+    AudioConfig config = 1;
+    bytes audio_content = 2;
+  }
+}
+
+message StreamingS2TResponse {
+  message Result {
+    string transcript = 1;
+    string translation = 2;
+    bool is_partial = 3;
+  }
+  Result result = 1;
+}
+
+message AvailableLanguageRequest {
+  string model = 1;
+}
+
+message AvailableLanguageResponse {
+  message LanguagePair {
+    string source_language_code = 1;
+    string target_language_code = 2;
+  }
+  repeated LanguagePair supported_language_pairs = 1;
+}
diff --git a/riva-ts-client/proto/riva_services.proto b/riva-ts-client/proto/riva_services.proto
new file mode 100644
index 00000000..82ca368a
--- /dev/null
+++ b/riva-ts-client/proto/riva_services.proto
@@ -0,0 +1,56 @@
"proto3"; + +package nvidia.riva; + +// Service definitions +service RivaSpeechSynthesis { + rpc Synthesize(SynthesizeRequest) returns (SynthesizeResponse); + rpc SynthesizeOnline(SynthesizeRequest) returns (stream SynthesizeResponse); + rpc GetRivaSynthesisConfig(GetRivaSynthesisConfigRequest) returns (GetRivaSynthesisConfigResponse); +} + +// Message types for Speech Synthesis +message SynthesizeRequest { + string text = 1; + string language_code = 2; + int32 sample_rate_hz = 3; + AudioEncoding encoding = 4; + string voice_name = 5; + optional string custom_dictionary = 6; +} + +message SynthesizeResponse { + bytes audio = 1; + AudioConfig audio_config = 2; +} + +message GetRivaSynthesisConfigRequest {} + +message GetRivaSynthesisConfigResponse { + message ModelConfig { + message Parameters { + string language_code = 1; + string voice_name = 2; + string subvoices = 3; + } + Parameters parameters = 1; + } + repeated ModelConfig model_config = 1; +} + +// Common types +message AudioConfig { + AudioEncoding encoding = 1; + int32 sample_rate_hz = 2; + string language_code = 3; + bool enable_word_time_offsets = 4; + int32 channels = 5; +} + +enum AudioEncoding { + ENCODING_UNSPECIFIED = 0; + LINEAR_PCM = 1; + FLAC = 2; + MULAW = 3; + ALAW = 4; +} diff --git a/riva-ts-client/scripts/asr/riva-streaming-asr-client.ts b/riva-ts-client/scripts/asr/riva-streaming-asr-client.ts new file mode 100644 index 00000000..62970d98 --- /dev/null +++ b/riva-ts-client/scripts/asr/riva-streaming-asr-client.ts @@ -0,0 +1,166 @@ + + +import * as fs from 'fs'; +import * as path from 'path'; +import { program } from 'commander'; +import { Auth, ASRService } from '../../src/client'; +import { addConnectionArgparseParameters } from '../utils/argparse'; +import { addAsrConfigArgparseParameters } from '../utils/asr_argparse'; +import { AudioEncoding, RecognitionConfig } from '../../src/client/asr/types'; +import { getWavFileParameters } from '../../src/client/asr/utils'; + +interface StreamingTranscriptionOptions { + inputFile: string; + server: string; + useSsl: boolean; + sslCert?: string; + metadata?: string[]; + maxAlternatives?: number; + profanityFilter?: boolean; + wordTimeOffsets?: boolean; + automaticPunctuation?: boolean; + noVerbatimTranscripts?: boolean; + speakerDiarization?: boolean; + diarizationMaxSpeakers?: string; + boostedLmWords?: string[]; + boostedLmScore?: string; + startHistory?: string; + startThreshold?: string; + stopHistory?: string; + stopHistoryEou?: string; + stopThreshold?: string; + stopThresholdEou?: string; +} + +async function streamingTranscriptionWorker(options: StreamingTranscriptionOptions, threadId: number) { + const outputFile = path.join(process.cwd(), `output_${threadId}.txt`); + + const auth = new Auth({ + uri: options.server, + useSsl: options.useSsl, + sslCert: options.sslCert, + metadata: options.metadata?.map(m => { + const [key, value] = m.split('='); + return [key, value] as [string, string]; + }) + }); + + const asr = new ASRService({ + auth, + serverUrl: options.server + }); + + const { encoding, sampleRate } = await getWavFileParameters(options.inputFile); + if (encoding !== AudioEncoding.LINEAR_PCM) { + throw new Error('Only LINEAR_PCM WAV files are supported'); + } + + const audioContent = fs.readFileSync(options.inputFile); + + try { + const config: RecognitionConfig = { + encoding, + sampleRateHertz: sampleRate, + languageCode: 'en-US', + maxAlternatives: options.maxAlternatives, + profanityFilter: options.profanityFilter, + enableWordTimeOffsets: 
options.wordTimeOffsets, + enableAutomaticPunctuation: options.automaticPunctuation, + enableSpeakerDiarization: options.speakerDiarization, + diarizationConfig: options.speakerDiarization ? { + enableSpeakerDiarization: true, + maxSpeakerCount: options.diarizationMaxSpeakers ? parseInt(options.diarizationMaxSpeakers) : undefined + } : undefined, + speechContexts: options.boostedLmWords ? [{ + phrases: options.boostedLmWords, + boost: options.boostedLmScore ? parseFloat(options.boostedLmScore) : 1.0 + }] : undefined, + endpointingConfig: { + startHistory: options.startHistory ? parseInt(options.startHistory) : undefined, + startThreshold: options.startThreshold ? parseFloat(options.startThreshold) : undefined, + stopHistory: options.stopHistory ? parseInt(options.stopHistory) : undefined, + stopHistoryEou: options.stopHistoryEou ? parseInt(options.stopHistoryEou) : undefined, + stopThreshold: options.stopThreshold ? parseFloat(options.stopThreshold) : undefined, + stopThresholdEou: options.stopThresholdEou ? parseFloat(options.stopThresholdEou) : undefined + } + }; + + const responses = await asr.streamingRecognize( + { content: audioContent }, + { config } + ); + + const outputStream = fs.createWriteStream(outputFile); + + for await (const response of responses) { + if (response.results.length === 0) continue; + + for (const result of response.results) { + if (result.alternatives.length === 0) continue; + + const transcript = result.alternatives[0].transcript; + if (result.isPartial) { + process.stdout.write(`\rIntermediate transcript: ${transcript}`); + } else { + console.log(`\nFinal transcript: ${transcript}`); + outputStream.write(`${transcript}\n`); + } + + if (options.wordTimeOffsets && !result.isPartial) { + console.log('\nWord timings:'); + for (const word of result.alternatives[0].words || []) { + const start = word.startTime || 0; + const end = word.endTime || 0; + console.log(` ${word.word}: ${start}s - ${end}s`); + } + } + } + } + + outputStream.end(); + } catch (error) { + console.error(`Thread ${threadId} error:`, error); + throw error; + } +} + +async function main() { + program + .description('Streaming transcription via Riva AI Services') + .requiredOption('--input-file ', 'A path to a local file to transcribe.') + .option('--num-parallel ', 'Number of parallel transcription threads.', '1') + .option('--boosted-lm-words ', 'List of words to boost when decoding.') + .option('--boosted-lm-score ', 'Score by which to boost the boosted words.', '4.0') + .option('--speaker-diarization', 'Enable speaker diarization.') + .option('--diarization-max-speakers ', 'Maximum number of speakers to identify.', '6') + .option('--start-history ', 'Number of frames to use for start threshold.', '30') + .option('--start-threshold ', 'Threshold for starting audio.', '0.0') + .option('--stop-history ', 'Number of frames to use for stop threshold.', '30') + .option('--stop-history-eou ', 'Number of frames to use for end-of-utterance detection.', '30') + .option('--stop-threshold ', 'Threshold for stopping audio.', '0.0') + .option('--stop-threshold-eou ', 'Threshold for end-of-utterance detection.', '0.0'); + + addConnectionArgparseParameters(program); + addAsrConfigArgparseParameters(program); + + program.parse(); + + const options = program.opts(); + const numParallel = parseInt(options.numParallel); + + const promises = Array.from({ length: numParallel }, (_, i) => + streamingTranscriptionWorker(options as StreamingTranscriptionOptions, i + 1) + ); + + try { + await Promise.all(promises); + 
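+    // Each worker has now flushed its final transcripts to its own
+    // output_<threadId>.txt file in the current working directory (see
+    // streamingTranscriptionWorker above).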
+    console.log('\nAll transcription threads completed successfully');
+  } catch (error) {
+    console.error('Error in transcription:', error);
+    process.exit(1);
+  }
+}
+
+if (require.main === module) {
+  main().catch(console.error);
+}
diff --git a/riva-ts-client/scripts/asr/transcribe-file-offline.ts b/riva-ts-client/scripts/asr/transcribe-file-offline.ts
new file mode 100644
index 00000000..81e93f1c
--- /dev/null
+++ b/riva-ts-client/scripts/asr/transcribe-file-offline.ts
@@ -0,0 +1,117 @@
+
+
+import * as fs from 'fs';
+import * as path from 'path';
+import { program } from 'commander';
+import { Auth, ASRService } from '../../src/client';
+import { addConnectionArgparseParameters } from '../utils/argparse';
+import { addAsrConfigArgparseParameters } from '../utils/asr_argparse';
+import { AudioEncoding, RecognitionConfig } from '../../src/client/asr/types';
+import { getWavFileParameters } from '../../src/client/asr/utils';
+
+async function main() {
+  program
+    .description('Offline file transcription via Riva AI Services')
+    .requiredOption('--input-file <file>', 'A path to a local file to transcribe.')
+    .option('--boosted-lm-words <words...>', 'List of words to boost when decoding.')
+    .option('--boosted-lm-score <score>', 'Score by which to boost the boosted words.', '4.0')
+    .option('--speaker-diarization', 'Enable speaker diarization.')
+    .option('--diarization-max-speakers <count>', 'Maximum number of speakers to identify.', '6')
+    .option('--start-history <frames>', 'Number of frames to use for start threshold.', '30')
+    .option('--start-threshold <value>', 'Threshold for starting audio.', '0.0')
+    .option('--stop-history <frames>', 'Number of frames to use for stop threshold.', '30')
+    .option('--stop-history-eou <frames>', 'Number of frames to use for end-of-utterance detection.', '30')
+    .option('--stop-threshold <value>', 'Threshold for stopping audio.', '0.0')
+    .option('--stop-threshold-eou <value>', 'Threshold for end-of-utterance detection.', '0.0');
+
+  addConnectionArgparseParameters(program);
+  addAsrConfigArgparseParameters(program);
+
+  program.parse();
+
+  const options = program.opts();
+  const inputFile = path.resolve(options.inputFile);
+
+  if (!fs.existsSync(inputFile)) {
+    console.error(`Input file ${inputFile} does not exist`);
+    process.exit(1);
+  }
+
+  const auth = new Auth({
+    uri: options.server,
+    useSsl: options.useSsl,
+    sslCert: options.sslCert,
+    metadata: options.metadata?.map(m => {
+      const [key, value] = m.split('=');
+      return [key, value] as [string, string];
+    })
+  });
+
+  const asr = new ASRService({
+    serverUrl: options.server,
+    auth
+  });
+
+  const { encoding, sampleRate } = await getWavFileParameters(inputFile);
+  if (encoding !== AudioEncoding.LINEAR_PCM) {
+    console.error('Only LINEAR_PCM WAV files are supported');
+    process.exit(1);
+  }
+
+  try {
+    const audioContent = fs.readFileSync(inputFile);
+
+    const config: RecognitionConfig = {
+      encoding,
+      sampleRateHertz: sampleRate,
+      languageCode: options.languageCode || 'en-US',
+      maxAlternatives: options.maxAlternatives ? parseInt(options.maxAlternatives) : 1,
+      profanityFilter: options.profanityFilter,
+      enableWordTimeOffsets: options.wordTimeOffsets,
+      enableAutomaticPunctuation: options.automaticPunctuation,
+      enableSpeakerDiarization: options.speakerDiarization,
+      diarizationConfig: options.speakerDiarization ? {
+        enableSpeakerDiarization: true,
+        maxSpeakerCount: options.diarizationMaxSpeakers ? parseInt(options.diarizationMaxSpeakers) : undefined
+      } : undefined,
+      speechContexts: options.boostedLmWords ? [{
+        phrases: options.boostedLmWords,
+        boost: options.boostedLmScore ? parseFloat(options.boostedLmScore) : 1.0
+      }] : undefined
+    };
+
+    const response = await asr.recognize(audioContent, config);
+
+    for (const result of response.results) {
+      if (result.alternatives.length === 0) continue;
+
+      const transcript = result.alternatives[0].transcript;
+      console.log(`\nTranscript: ${transcript}`);
+
+      if (options.wordTimeOffsets) {
+        console.log('\nWord timings:');
+        for (const word of result.alternatives[0].words || []) {
+          const start = word.startTime || 0;
+          const end = word.endTime || 0;
+          console.log(`  ${word.word}: ${start}s - ${end}s`);
+        }
+      }
+
+      if (options.speakerDiarization && result.alternatives[0].words) {
+        console.log('\nSpeaker diarization:');
+        for (const word of result.alternatives[0].words) {
+          if (word.speakerTag) {
+            console.log(`  Speaker ${word.speakerTag}: ${word.word}`);
+          }
+        }
+      }
+    }
+  } catch (error) {
+    console.error('Error in transcription:', error);
+    process.exit(1);
+  }
+}
+
+if (require.main === module) {
+  main().catch(console.error);
+}
diff --git a/riva-ts-client/scripts/asr/transcribe-file.ts b/riva-ts-client/scripts/asr/transcribe-file.ts
new file mode 100644
index 00000000..d14aef90
--- /dev/null
+++ b/riva-ts-client/scripts/asr/transcribe-file.ts
@@ -0,0 +1,175 @@
+
+
+import * as fs from 'fs';
+import * as path from 'path';
+import * as wavPlayer from 'node-wav-player';
+import { program } from 'commander';
+import { Auth, ASRService } from '../../src/client';
+import { addConnectionArgparseParameters } from '../utils/argparse';
+import { addAsrConfigArgparseParameters } from '../utils/asr_argparse';
+import { AudioEncoding, StreamingRecognitionConfig, AudioChunk } from '../../src/client/asr/types';
+import { getWavFileParameters } from '../../src/client/asr/utils';
+
+class AudioPlayer {
+  private tempFile: string;
+
+  constructor() {
+    this.tempFile = path.join(process.cwd(), '.temp_audio.wav');
+  }
+
+  async play(audioData: Buffer): Promise<void> {
+    fs.writeFileSync(this.tempFile, audioData);
+    try {
+      await wavPlayer.play({ path: this.tempFile });
+    } catch (error) {
+      console.error('Error playing audio:', error);
+    }
+  }
+
+  close(): void {
+    if (fs.existsSync(this.tempFile)) {
+      fs.unlinkSync(this.tempFile);
+    }
+  }
+}
+
+async function* createAudioSource(fileStream: fs.ReadStream): AsyncGenerator<AudioChunk> {
+  try {
+    for await (const chunk of fileStream) {
+      if (chunk instanceof Buffer && chunk.length > 0) {
+        yield { audioContent: chunk };
+      }
+    }
+  } finally {
+    fileStream.destroy();
+  }
+}
+
+async function main() {
+  program
+    .description(
+      'Streaming transcription of a file via Riva AI Services. Streaming means that audio is sent to a ' +
+      'server in small chunks and transcripts are returned as soon as these transcripts are ready. ' +
+      'You may play transcribed audio simultaneously with transcribing by setting --play-audio option.'
+    )
+    .option('--input-file <file>', 'A path to a local file to stream.')
+    .option('--list-models', 'List available models.')
+    .option('--show-intermediate', 'Show intermediate transcripts as they are available.')
+    .option('--play-audio', 'Whether to play input audio simultaneously with transcribing.')
+    .option('--file-streaming-chunk <frames>', 'A maximum number of frames in one chunk sent to server.', '1600')
+    .option(
+      '--simulate-realtime',
+      'Option to simulate realtime transcription. Audio fragments are sent to a server at a pace that mimics normal speech.'
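+      // NOTE: --simulate-realtime and --file-streaming-chunk are parsed but not
+      // yet applied below; audio is streamed as fast as the file can be read.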
+    );
+
+  addConnectionArgparseParameters(program);
+  addAsrConfigArgparseParameters(program);
+
+  program.parse();
+
+  const options = program.opts();
+
+  if (!options.inputFile && !options.listModels) {
+    console.error('Either --input-file or --list-models must be specified');
+    process.exit(1);
+  }
+
+  const auth = new Auth({
+    uri: options.server,
+    useSsl: options.useSsl,
+    sslCert: options.sslCert,
+    metadata: options.metadata?.map(m => {
+      const [key, value] = m.split('=');
+      return [key, value] as [string, string];
+    })
+  });
+
+  const asr = new ASRService({
+    serverUrl: options.server,
+    auth
+  });
+
+  if (options.listModels) {
+    try {
+      const models = await asr.listModels();
+      console.log('Available models:');
+      for (const model of models) {
+        console.log(`  ${model.name}`);
+        console.log(`    Languages: ${model.languages.join(', ')}`);
+        console.log(`    Sample Rate: ${model.sampleRate}Hz`);
+        console.log(`    Streaming: ${model.streaming}`);
+        console.log();
+      }
+      return;
+    } catch (error) {
+      console.error('Error listing models:', error);
+      process.exit(1);
+    }
+  }
+
+  const inputFile = path.resolve(options.inputFile);
+  if (!fs.existsSync(inputFile)) {
+    console.error(`Input file ${inputFile} does not exist`);
+    process.exit(1);
+  }
+
+  const { encoding, sampleRate } = await getWavFileParameters(inputFile);
+  if (encoding !== AudioEncoding.LINEAR_PCM) {
+    console.error('Only LINEAR_PCM WAV files are supported');
+    process.exit(1);
+  }
+
+  const fileStream = fs.createReadStream(inputFile);
+  let audioPlayer: AudioPlayer | null = null;
+
+  if (options.playAudio) {
+    audioPlayer = new AudioPlayer();
+  }
+
+  try {
+    const config: StreamingRecognitionConfig = {
+      config: {
+        encoding,
+        sampleRateHertz: sampleRate,
+        languageCode: options.languageCode || 'en-US',
+        maxAlternatives: options.maxAlternatives ? parseInt(options.maxAlternatives) : 1,
+        profanityFilter: options.profanityFilter,
+        enableWordTimeOffsets: options.wordTimeOffsets,
+        enableAutomaticPunctuation: options.automaticPunctuation
+      }
+    };
+
+    const audioSource = createAudioSource(fileStream);
+    const responses = asr.streamingRecognize(audioSource, config);
+
+    // Recognition responses carry no audio, so --play-audio plays the original
+    // input file alongside transcription.
+    if (audioPlayer) {
+      void audioPlayer.play(fs.readFileSync(inputFile));
+    }
+
+    for await (const response of responses) {
+      if (response.results.length > 0) {
+        const result = response.results[0];
+        if (result.alternatives.length > 0) {
+          const transcript = result.alternatives[0].transcript;
+          if (!result.isPartial) {
+            console.log(`Final transcript: ${transcript}`);
+          } else if (options.showIntermediate) {
+            console.log(`Intermediate transcript: ${transcript}`);
+          }
+        }
+      }
+    }
+  } catch (error) {
+    console.error('Error in transcription:', error);
+    process.exit(1);
+  } finally {
+    if (audioPlayer) {
+      audioPlayer.close();
+    }
+  }
+}
+
+if (require.main === module) {
+  main().catch(console.error);
+}
diff --git a/riva-ts-client/scripts/asr/transcribe-mic.ts b/riva-ts-client/scripts/asr/transcribe-mic.ts
new file mode 100644
index 00000000..1d3d4d5e
--- /dev/null
+++ b/riva-ts-client/scripts/asr/transcribe-mic.ts
@@ -0,0 +1,122 @@
+
+
+import { program } from 'commander';
+import * as mic from 'mic';
+import { Auth, ASRService } from '../../src/client';
+import { addConnectionArgparseParameters } from '../utils/argparse';
+import { addAsrConfigArgparseParameters } from '../utils/asr_argparse';
+import { AudioEncoding, StreamingRecognitionConfig, AudioChunk } from '../../src/client/asr/types';
+
+async function* createMicAudioSource(micInstance: mic.MicInstance): AsyncGenerator<AudioChunk> {
+  const audioStream = micInstance.getAudioStream();
+  try {
+    for await (const chunk of audioStream) {
+      if (chunk instanceof Buffer && chunk.length > 0) {
+        yield { audioContent: chunk };
+      }
+    }
+  } finally {
+    micInstance.stop();
+  }
+}
+
+async function main() {
+  program
+    .description('Streaming transcription from microphone via Riva AI Services')
+    .option('--list-models', 'List available models.')
+    .option('--show-intermediate', 'Show intermediate transcripts as they are available.')
+    .option('--device <device>', 'Input device to use.')
+    .option('--rate <rate>', 'Input device sample rate.', '16000')
+    .option('--channels <count>', 'Number of input channels.', '1');
+
+  addConnectionArgparseParameters(program);
+  addAsrConfigArgparseParameters(program);
+
+  program.parse();
+
+  const options = program.opts();
+
+  const auth = new Auth({
+    uri: options.server,
+    useSsl: options.useSsl,
+    sslCert: options.sslCert,
+    metadata: options.metadata?.map(m => {
+      const [key, value] = m.split('=');
+      return [key, value] as [string, string];
+    })
+  });
+
+  const asr = new ASRService({
+    serverUrl: options.server,
+    auth
+  });
+
+  if (options.listModels) {
+    try {
+      const models = await asr.listModels();
+      console.log('Available models:');
+      for (const model of models) {
+        console.log(`  ${model.name}`);
+        console.log(`    Languages: ${model.languages.join(', ')}`);
+        console.log(`    Sample Rate: ${model.sampleRate}Hz`);
+        console.log(`    Streaming: ${model.streaming}`);
+        console.log();
+      }
+      return;
+    } catch (error) {
+      console.error('Error listing models:', error);
+      process.exit(1);
+    }
+  }
+
+  const micInstance = mic({
+    rate: options.rate,
+    channels: options.channels,
+    debug: false,
+    device: options.device
+  });
+
+  const config: StreamingRecognitionConfig = {
+    config: {
+      encoding: AudioEncoding.LINEAR_PCM,
+      sampleRateHertz: parseInt(options.rate),
+      languageCode: options.languageCode || 'en-US',
+      maxAlternatives: options.maxAlternatives ? parseInt(options.maxAlternatives) : 1,
+      profanityFilter: options.profanityFilter,
+      enableWordTimeOffsets: options.wordTimeOffsets,
+      enableAutomaticPunctuation: options.automaticPunctuation
+    }
+  };
+
+  try {
+    const audioSource = createMicAudioSource(micInstance);
+    micInstance.start();
+
+    console.log('Listening... Press Ctrl+C to stop.');
+
+    const responses = asr.streamingRecognize(audioSource, config);
+
+    for await (const response of responses) {
+      if (response.results.length > 0) {
+        const result = response.results[0];
+        if (result.alternatives.length > 0) {
+          const transcript = result.alternatives[0].transcript;
+          if (!result.isPartial) {
+            console.log(`Final transcript: ${transcript}`);
+          } else if (options.showIntermediate) {
+            console.log(`Intermediate transcript: ${transcript}`);
+          }
+        }
+      }
+    }
+  } catch (error) {
+    console.error('Error in transcription:', error);
+    process.exit(1);
+  } finally {
+    micInstance.stop();
+  }
+}
+
+if (require.main === module) {
+  main().catch(console.error);
+}
diff --git a/riva-ts-client/scripts/generate-protos.ts b/riva-ts-client/scripts/generate-protos.ts
new file mode 100644
index 00000000..9f8529fc
--- /dev/null
+++ b/riva-ts-client/scripts/generate-protos.ts
@@ -0,0 +1,18 @@
+import { execSync } from 'child_process';
+import * as fs from 'fs';
+import * as path from 'path';
+
+const PROTO_DIR = path.resolve(__dirname, '../proto');
+const OUT_DIR = path.resolve(__dirname, '../src/proto');
+
+try {
+  // protoc does not create the output directory itself.
+  fs.mkdirSync(OUT_DIR, { recursive: true });
+
+  execSync(`protoc --plugin=protoc-gen-ts_proto=./node_modules/.bin/protoc-gen-ts_proto \
+    --ts_proto_out=${OUT_DIR} \
+    --ts_proto_opt=esModuleInterop=true \
+    --proto_path=${PROTO_DIR} \
+    ${PROTO_DIR}/*.proto`);
+
+  console.log('Protocol buffers generated successfully');
+} catch (error) {
+  console.error('Error generating protocol buffers:', error);
+  process.exit(1);
+}
diff --git a/riva-ts-client/scripts/nlp/punctuation-client.ts b/riva-ts-client/scripts/nlp/punctuation-client.ts
new file mode 100644
index 00000000..e3492a6b
--- /dev/null
+++ b/riva-ts-client/scripts/nlp/punctuation-client.ts
@@ -0,0 +1,178 @@
+import * as fs from 'fs';
+import { program } from 'commander';
+import { Auth, NLPService } from '../../src/client';
+import { addConnectionArgparseParameters } from '../utils/argparse';
+import * as readline from 'readline';
+
+interface TestCase {
+  input: string;
+  expected: string;
+}
+
+const TEST_CASES: TestCase[] = [
+  {
+    input: 'can you prove that you are self aware',
+    expected: 'Can you prove that you are self-aware?'
+  },
+  {
+    input: 'hello how are you today',
+    expected: 'Hello, how are you today?'
+  },
+  {
+    input: 'i like pizza pasta and ice cream',
+    expected: 'I like pizza, pasta, and ice cream.'
+  },
+  {
+    input: 'what time is it',
+    expected: 'What time is it?'
+  },
+  {
+    input: 'my name is john and i live in new york',
+    expected: 'My name is John and I live in New York.'
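+    // The expected outputs above assume the default English punctuation and
+    // capitalization model; other models may punctuate differently.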
+  }
+];
+
+async function runPunctCapit(
+  nlpService: NLPService,
+  query: string,
+  modelName: string = 'punctuation',
+  languageCode: string = 'en-US'
+): Promise<string> {
+  const start = Date.now();
+  try {
+    const response = await nlpService.punctuateText(query, modelName);
+    const result = response.text;
+    const timeTaken = (Date.now() - start) / 1000;
+    console.log(`Time taken: ${timeTaken.toFixed(3)}s`);
+    return result;
+  } catch (error) {
+    console.error('Error during punctuation:', error);
+    throw error;
+  }
+}
+
+async function runTests(nlpService: NLPService, modelName?: string, languageCode: string = 'en-US'): Promise<void> {
+  console.log('Running tests...\n');
+  let passed = 0;
+  let failed = 0;
+
+  for (const [index, testCase] of TEST_CASES.entries()) {
+    console.log(`Test ${index + 1}:`);
+    console.log(`Input: "${testCase.input}"`);
+    console.log(`Expected: "${testCase.expected}"`);
+
+    try {
+      const result = await runPunctCapit(nlpService, testCase.input, modelName, languageCode);
+      console.log(`Got: "${result}"`);
+
+      if (result === testCase.expected) {
+        console.log('✓ PASSED\n');
+        passed++;
+      } else {
+        console.log('✗ FAILED\n');
+        failed++;
+      }
+    } catch (error) {
+      console.log('✗ FAILED (error occurred)\n');
+      failed++;
+    }
+  }
+
+  console.log(`Summary: ${passed} passed, ${failed} failed`);
+}
+
+async function interactive(nlpService: NLPService, modelName?: string, languageCode: string = 'en-US'): Promise<void> {
+  const rl = readline.createInterface({
+    input: process.stdin,
+    output: process.stdout
+  });
+
+  while (true) {
+    try {
+      const query = await new Promise<string>((resolve) => {
+        rl.question('Enter a query (or Ctrl+C to exit): ', resolve);
+      });
+
+      const result = await runPunctCapit(nlpService, query, modelName, languageCode);
+      console.log(`Result: "${result}"\n`);
+    } catch (error) {
+      console.error('Error occurred:', error);
+    }
+  }
+}
+
+async function main() {
+  program
+    .description('Client app to restore Punctuation and Capitalization with Riva')
+    .option(
+      '--model <model>',
+      'Model on Riva Server to execute. If this parameter is missing, then the server will try to select a first available Punctuation & Capitalization model.'
+    )
+    .option('--query <query>', 'Input Query', 'can you prove that you are self aware')
+    .option(
+      '--run-tests',
+      'Flag to run sanity tests. If this option is chosen, then options --query and --interactive are ignored and a model is run on several hardcoded examples.'
+    )
+    .option(
+      '--interactive',
+      'If this option is set, then --query argument is ignored and the script suggests user to enter queries to standard input.'
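+      // Interactive mode reads one query per line from stdin and echoes the
+      // punctuated result until the process receives Ctrl+C (see interactive()).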
+    )
+    .option('--language-code <code>', 'Language code of the model to be used.', 'en-US')
+    .option('--input-file <file>', 'Input file with text to punctuate')
+    .option('--output-file <file>', 'Output file to write punctuated text');
+
+  addConnectionArgparseParameters(program);
+
+  program.parse();
+
+  const options = program.opts();
+
+  const auth = new Auth({
+    uri: options.server,
+    useSsl: options.useSsl,
+    sslCert: options.sslCert,
+    metadata: options.metadata?.map(m => {
+      const [key, value] = m.split('=');
+      return [key, value] as [string, string];
+    })
+  });
+
+  const nlpService = new NLPService({
+    serverUrl: options.server,
+    auth
+  });
+
+  try {
+    if (options.runTests) {
+      await runTests(nlpService, options.model, options.languageCode);
+    } else if (options.interactive) {
+      await interactive(nlpService, options.model, options.languageCode);
+    } else {
+      let text = options.query;
+      if (options.inputFile) {
+        text = fs.readFileSync(options.inputFile, 'utf-8');
+      }
+
+      const result = await runPunctCapit(nlpService, text, options.model, options.languageCode);
+
+      if (options.outputFile) {
+        fs.writeFileSync(options.outputFile, result);
+        console.log(`Punctuated text written to ${options.outputFile}`);
+      } else {
+        console.log(`Result: "${result}"`);
+      }
+    }
+  } catch (error) {
+    console.error('Error:', error);
+    process.exit(1);
+  }
+}
+
+if (require.main === module) {
+  main().catch(console.error);
+
+  process.on('SIGINT', () => {
+    console.log('\nExiting...');
+    process.exit(0);
+  });
+}
diff --git a/riva-ts-client/scripts/nmt/nmt.ts b/riva-ts-client/scripts/nmt/nmt.ts
new file mode 100644
index 00000000..afef668b
--- /dev/null
+++ b/riva-ts-client/scripts/nmt/nmt.ts
@@ -0,0 +1,204 @@
+
+
+import * as fs from 'fs';
+import * as readline from 'readline';
+import { program } from 'commander';
+import { Auth, NeuralMachineTranslationService } from '../../src/client';
+import { addConnectionArgparseParameters } from '../utils/argparse';
+import { TranslateRequest, TranslateResponse } from '../../src/client/nmt/types';
+
+interface DNTPhraseMapping {
+  phrase: string;
+  replacement?: string;
+}
+
+function parseDNTPhrase(line: string): DNTPhraseMapping | null {
+  line = line.trim();
+  if (!line) return null;
+
+  const parts = line.split('##').map(p => p.trim());
+  if (parts[0]) {
+    return {
+      phrase: parts[0],
+      replacement: parts[1]
+    };
+  }
+  return null;
+}
+
+function readDntPhrasesFile(filePath: string): string[] {
+  if (!filePath) return [];
+
+  try {
+    const content = fs.readFileSync(filePath, 'utf-8');
+    return content
+      .split('\n')
+      .map(parseDNTPhrase)
+      .filter((mapping): mapping is DNTPhraseMapping => mapping !== null)
+      // Keep the "phrase##replacement" form so replacements are not lost.
+      .map(mapping => mapping.replacement ? `${mapping.phrase}##${mapping.replacement}` : mapping.phrase);
+  } catch (error) {
+    console.error('Error reading DNT phrases file:', error);
+    return [];
+  }
+}
+
+function formatTranslationResponse(response: TranslateResponse): string {
+  if (response.translations && response.translations.length > 0) {
+    const translation = response.translations[0];
+    return `${translation.text} (confidence: ${translation.score.toFixed(2)})`;
+  }
+  return response.text;
+}
+
+async function interactive(
+  nmtService: NeuralMachineTranslationService,
+  config: {
+    sourceLanguage: string;
+    targetLanguage: string;
+    model?: string;
+    doNotTranslatePhrases?: string[];
+  }
+): Promise<void> {
+  const rl = readline.createInterface({
+    input: process.stdin,
+    output: process.stdout,
+    terminal: true
+  });
+
+  console.log('\nEnter text to translate (press Ctrl+C to exit)\n');
+
+  while (true) {
+    try {
+      const text = await new Promise<string>((resolve) => {
+        rl.question('> ', resolve);
+      });
+
+      if (!text.trim()) {
+        console.log('Please enter some text to translate.\n');
+        continue;
+      }
+
+      const request: TranslateRequest = {
+        text,
+        sourceLanguage: config.sourceLanguage,
+        targetLanguage: config.targetLanguage,
+        model: config.model,
+        doNotTranslatePhrases: config.doNotTranslatePhrases
+      };
+
+      const response = await nmtService.translate(request);
+      console.log(`Translation: ${formatTranslationResponse(response)}\n`);
+    } catch (error) {
+      console.error('Translation error:', error instanceof Error ? error.message : 'Unknown error');
+      console.log('Please try again.\n');
+    }
+  }
+}
+
+async function translateSingle(
+  nmtService: NeuralMachineTranslationService,
+  config: {
+    text: string;
+    sourceLanguage: string;
+    targetLanguage: string;
+    model?: string;
+    doNotTranslatePhrases?: string[];
+  }
+): Promise<void> {
+  try {
+    const request: TranslateRequest = {
+      text: config.text,
+      sourceLanguage: config.sourceLanguage,
+      targetLanguage: config.targetLanguage,
+      model: config.model,
+      doNotTranslatePhrases: config.doNotTranslatePhrases
+    };
+
+    const response = await nmtService.translate(request);
+    console.log(`Translation: ${formatTranslationResponse(response)}`);
+  } catch (error) {
+    console.error('Translation error:', error instanceof Error ? error.message : 'Unknown error');
+    process.exit(1);
+  }
+}
+
+async function listSupportedLanguages(
+  nmtService: NeuralMachineTranslationService,
+  model?: string
+): Promise<void> {
+  try {
+    const response = await nmtService.get_supported_language_pairs(model || '');
+    console.log('\nSupported language pairs:');
+    response.supportedLanguagePairs.forEach(pair => {
+      console.log(`  ${pair.sourceLanguageCode} -> ${pair.targetLanguageCode}`);
+    });
+  } catch (error) {
+    console.error('Error fetching supported languages:', error instanceof Error ? error.message : 'Unknown error');
+    process.exit(1);
+  }
+}
+
+async function main() {
+  program
+    .description('Neural Machine Translation (NMT) client for Riva')
+    .option('--source-language <code>', 'Source language code', 'en-US')
+    .option('--target-language <code>', 'Target language code', 'es-US')
+    .option('--model <model>', 'Model name to use for translation')
+    .option('--text <text>', 'Text to translate')
+    .option('--interactive', 'Enable interactive mode')
+    .option(
+      '--dnt-file <file>',
+      'Path to file containing "do not translate" phrases. Each line should contain a phrase, ' +
+      'optionally followed by ## and a replacement'
+    )
+    .option('--list-languages', 'List supported language pairs for the specified model');
+
+  addConnectionArgparseParameters(program);
+  program.parse();
+
+  const options = program.opts();
+
+  if (!options.listLanguages && !options.interactive && !options.text) {
+    console.error('Either --text, --interactive, or --list-languages must be specified');
+    process.exit(1);
+  }
+
+  const auth = new Auth({
+    uri: options.server,
+    useSsl: options.useSsl,
+    sslCert: options.sslCert,
+    metadata: options.metadata?.map(m => {
+      const [key, value] = m.split('=');
+      return [key, value] as [string, string];
+    })
+  });
+
+  const nmtService = new NeuralMachineTranslationService({
+    serverUrl: options.server,
+    auth
+  });
+
+  const config = {
+    sourceLanguage: options.sourceLanguage,
+    targetLanguage: options.targetLanguage,
+    model: options.model,
+    doNotTranslatePhrases: options.dntFile ? readDntPhrasesFile(options.dntFile) : undefined
+  };
+
+  if (options.listLanguages) {
+    await listSupportedLanguages(nmtService, options.model);
+  } else if (options.interactive) {
+    await interactive(nmtService, config);
+  } else if (options.text) {
+    await translateSingle(nmtService, { ...config, text: options.text });
+  }
+}
+
+if (require.main === module) {
+  main().catch(error => {
+    console.error('Fatal error:', error instanceof Error ? error.message : 'Unknown error');
+    process.exit(1);
+  });
+
+  process.on('SIGINT', () => {
+    console.log('\nExiting...');
+    process.exit(0);
+  });
+}
diff --git a/riva-ts-client/scripts/tts/talk.ts b/riva-ts-client/scripts/tts/talk.ts
new file mode 100644
index 00000000..04bc0a44
--- /dev/null
+++ b/riva-ts-client/scripts/tts/talk.ts
@@ -0,0 +1,245 @@
+
+
+import * as fs from 'fs';
+import * as path from 'path';
+import * as wavPlayer from 'node-wav-player';
+import { program } from 'commander';
+import { Auth, SpeechSynthesisService } from '../../src/client';
+import { addConnectionArgparseParameters } from '../utils/argparse';
+import { AudioEncoding } from '../../src/client/asr/types';
+
+interface CustomDictionary {
+  [key: string]: string;
+}
+
+function readFileToDict(filePath: string): CustomDictionary {
+  const resultDict: CustomDictionary = {};
+  const fileContent = fs.readFileSync(filePath, 'utf-8');
+  const lines = fileContent.split('\n');
+
+  for (const [lineNumber, line] of lines.entries()) {
+    const trimmedLine = line.trim();
+    if (!trimmedLine) continue;
+
+    try {
+      const [key, value] = trimmedLine.split(/\s{2,}/);
+      if (key && value) {
+        resultDict[key.trim()] = value.trim();
+      } else {
+        console.warn(`Warning: Malformed line ${lineNumber + 1}`);
+      }
+    } catch (error) {
+      console.warn(`Warning: Malformed line ${lineNumber + 1}`);
+      continue;
+    }
+  }
+
+  if (Object.keys(resultDict).length === 0) {
+    throw new Error("No valid entries found in the file.");
+  }
+
+  return resultDict;
+}
+
+class AudioPlayer {
+  private tempFile: string;
+
+  constructor() {
+    this.tempFile = path.join(process.cwd(), '.temp_audio.wav');
+  }
+
+  async write(audioData: Uint8Array): Promise<void> {
+    fs.writeFileSync(this.tempFile, Buffer.from(audioData));
+    try {
+      await wavPlayer.play({ path: this.tempFile });
+    } catch (error) {
+      console.error('Error playing audio:', error);
+    }
+  }
+
+  close(): void {
+    if (fs.existsSync(this.tempFile)) {
+      fs.unlinkSync(this.tempFile);
+    }
+  }
+}
+
+async function main() {
+  program
+    .description('Speech synthesis via Riva AI Services')
+    .option('--text <text>', 'Text input to synthesize.')
+    .option('--list-devices', 'List output audio devices indices.')
+    .option('--list-voices', 'List available voices.')
+    .option('--voice <voice>', 'A voice name to use. If this parameter is missing, then the server will try a first available model based on parameter `--language-code`.')
+    .option('--audio_prompt_file <file>', 'An input audio prompt (.wav) file for zero shot model. This is required to do zero shot inferencing.')
+    .option('-o, --output <file>', 'Output .wav file to write synthesized audio.', 'output.wav')
+    .option('--quality <n>', 'Number of times the decoder should be run on the output audio. A higher number improves quality but introduces latency.', parseInt)
+    .option('--play-audio', 'Whether to play synthesized audio as it is generated.')
+    .option('--output-device <device>', 'Output device to use.', parseInt)
+    .option('--language-code <code>', 'A language of input text.', 'en-US')
+    .option('--sample-rate-hz <rate>', 'Number of audio frames per second in synthesized audio.', parseInt, 44100)
+    .option('--custom-dictionary <file>', 'A file path to a user dictionary with key-value pairs separated by double spaces.')
+    .option('--stream', 'If set, streaming synthesis is applied. Audio is yielded as it gets ready. Otherwise, synthesized audio is returned in 1 response when all text is processed.');
+
+  addConnectionArgparseParameters(program);
+  program.parse();
+
+  const options = program.opts();
+  const outputPath = path.resolve(options.output);
+
+  if (fs.existsSync(outputPath) && fs.statSync(outputPath).isDirectory()) {
+    console.error("Output path must be a file, not a directory");
+    return;
+  }
+
+  if (options.listDevices) {
+    console.log("Audio devices listing is not supported in this version");
+    return;
+  }
+
+  const auth = new Auth({
+    uri: options.server,
+    useSsl: options.useSsl,
+    sslCert: options.sslCert,
+    metadata: options.metadata?.map(m => {
+      const [key, value] = m.split('=');
+      return [key, value] as [string, string];
+    })
+  });
+
+  const service = new SpeechSynthesisService({
+    serverUrl: options.server,
+    auth
+  });
+
+  let soundStream: AudioPlayer | null = null;
+  let outFile: fs.WriteStream | null = null;
+
+  if (options.listVoices) {
+    try {
+      const configResponse = await service.getRivaSynthesisConfig();
+      const ttsModels: { [key: string]: { voices: string[] } } = {};
+
+      for (const modelConfig of configResponse.modelConfig) {
+        const languageCode = modelConfig.parameters.languageCode;
+        const voiceName = modelConfig.parameters.voiceName;
+        const subvoices = modelConfig.parameters.subvoices
+          .split(',')
+          .map((voice: string) => voice.split(':')[0]);
+        const fullVoiceNames = subvoices.map((subvoice: string) => `${voiceName}.${subvoice}`);
+
+        if (languageCode in ttsModels) {
+          ttsModels[languageCode].voices.push(...fullVoiceNames);
+        } else {
+          ttsModels[languageCode] = { voices: fullVoiceNames };
+        }
+      }
+
+      console.log(JSON.stringify(ttsModels, null, 4));
+      return;
+    } catch (error) {
+      console.error('Error getting voices:', error);
+      return;
+    }
+  }
+
+  if (!options.text) {
+    console.error("No input text provided");
+    return;
+  }
+
+  try {
+    if (options.outputDevice !== undefined || options.playAudio) {
+      soundStream = new AudioPlayer();
+    }
+
+    if (outputPath) {
+      outFile = fs.createWriteStream(outputPath);
+    }
+
+    let customDictionaryInput: CustomDictionary = {};
+    if (options.customDictionary) {
+      customDictionaryInput = readFileToDict(options.customDictionary);
+    }
+
+    console.log("Generating audio for request...");
+    const start = Date.now();
+
+    if (options.stream) {
+      const responses = service.synthesizeOnline(
+        options.text,
+        options.voice,
+        options.languageCode,
+        AudioEncoding.LINEAR_PCM,
+        options.sampleRateHz,
+        options.audioPromptFile,
+        AudioEncoding.LINEAR_PCM,
+        options.quality ?? 20,
+        customDictionaryInput
+      );
+
+      let first = true;
+      for await (const resp of responses) {
+        const stop = Date.now();
+        if (first) {
+          console.log(`Time to first audio: ${(stop - start) / 1000}s`);
+          first = false;
+        }
+        if (soundStream) {
+          await soundStream.write(resp.audio);
+        }
+        if (outFile) {
+          outFile.write(Buffer.from(resp.audio));
+        }
+      }
+    } else {
+      const resp = await service.synthesize(
+        options.text,
+        options.voice,
+        options.languageCode,
+        AudioEncoding.LINEAR_PCM,
+        options.sampleRateHz,
+        options.audioPromptFile,
+        AudioEncoding.LINEAR_PCM,
+        options.quality ?? 20,
+        false,
+        customDictionaryInput
+      );
+
+      const stop = Date.now();
+      console.log(`Time spent: ${(stop - start) / 1000}s`);
+
+      if (soundStream) {
+        await soundStream.write(resp.audio);
+      }
+      if (outFile) {
+        outFile.write(Buffer.from(resp.audio));
+      }
+    }
+  } catch (error) {
+    if (error instanceof Error) {
+      console.error(error.message);
+    } else {
+      console.error('An unknown error occurred');
+    }
+  } finally {
+    if (outFile) {
+      outFile.end();
+    }
+    if (soundStream) {
+      soundStream.close();
+    }
+  }
+}
+
+if (require.main === module) {
+  main().catch(console.error);
+}
diff --git a/riva-ts-client/scripts/utils/argparse.ts b/riva-ts-client/scripts/utils/argparse.ts
new file mode 100644
index 00000000..0291e277
--- /dev/null
+++ b/riva-ts-client/scripts/utils/argparse.ts
@@ -0,0 +1,9 @@
+import { Command } from 'commander';
+
+export function addConnectionArgparseParameters(program: Command): Command {
+  return program
+    .option('--ssl-cert <file>', 'Path to SSL certificate file.')
+    .option('--use-ssl', 'Use SSL/TLS connection to the server.')
+    .option('--server <address>', 'Server address.', 'localhost:50051')
+    .option('--metadata <key=value...>', 'Metadata to pass to gRPC. Example: key1=val1 key2=val2');
+}
diff --git a/riva-ts-client/scripts/utils/asr_argparse.ts b/riva-ts-client/scripts/utils/asr_argparse.ts
new file mode 100644
index 00000000..12838538
--- /dev/null
+++ b/riva-ts-client/scripts/utils/asr_argparse.ts
@@ -0,0 +1,15 @@
+// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+import { Command } from 'commander';
+
+export function addAsrConfigArgparseParameters(program: Command): void {
+  program
+    .option('--max-alternatives <n>', 'Maximum number of alternative transcripts to return.', '1')
+    .option('--profanity-filter', 'Enable profanity filtering.')
+    .option('--word-time-offsets', 'Enable word time offset information.')
+    .option('--automatic-punctuation', 'Enable automatic punctuation.')
+    .option('--no-verbatim-transcripts', 'Disable verbatim transcripts.')
+    .option('--language-code <code>', 'Language code for the request.', 'en-US');
+}
diff --git a/riva-ts-client/src/client/asr/index.ts b/riva-ts-client/src/client/asr/index.ts
new file mode 100644
index 00000000..dce71130
--- /dev/null
+++ b/riva-ts-client/src/client/asr/index.ts
@@ -0,0 +1,324 @@
+import * as grpc from '@grpc/grpc-js';
+import { BaseClient } from '../base';
+import { handleGrpcError } from '../errors';
+import * as fs from 'fs';
+import {
+  ASRServiceClient,
+  AudioChunk,
+  AudioSource,
+  AudioContentSource,
+  RecognitionConfig,
+  RecognizeResponse,
+  StreamingRecognitionConfig,
+  StreamingRecognizeResponse,
+  SpeechContext,
+  SpeakerDiarizationConfig,
+  EndpointingConfig,
+  WavFileParameters,
+  AudioEncoding,
+  ListModelsResponse
+} from './types';
+import { getProtoClient } from '../utils/proto';
+
+/**
+ * Get WAV file parameters
+ * @param filePath Path to WAV file
+ */
+export function getWavFileParameters(filePath: string): WavFileParameters | null {
+  try {
+    const buffer = fs.readFileSync(filePath);
+    if (buffer.toString('ascii', 0, 4) !== 'RIFF' || buffer.toString('ascii', 8, 12) !== 'WAVE') {
+      return null;
+    }
+
+    // Parse WAV header (assumes the canonical 44-byte header layout)
+    const sampleRate = buffer.readUInt32LE(24);
+    const numChannels = buffer.readUInt16LE(22);
+    const bitsPerSample = buffer.readUInt16LE(34);
+    const dataOffset = 44; // Standard WAV header size
+    const dataSize = buffer.readUInt32LE(40);
+    const numFrames = dataSize / (numChannels * (bitsPerSample / 8));
+
+    return {
+      nframes: numFrames,
+      framerate: sampleRate,
+      duration: numFrames / sampleRate,
+      nchannels: numChannels,
+      sampwidth: bitsPerSample / 8,
+      dataOffset
+    };
+  } catch {
+    return null;
+  }
+}
+
+/**
+ * Sleep for the duration of an audio chunk
+ * @param _chunk Audio chunk (unused; kept to match the delay callback signature)
+ * @param duration Duration in seconds
+ */
+export function sleepAudioLength(_chunk: Uint8Array, duration: number): Promise<void> {
+  return new Promise(resolve => setTimeout(resolve, duration * 1000));
+}
+
+/**
+ * Iterator for audio chunks from a file
+ */
Uint8Array, duration: number) => Promise<void> + ) { + this._delayCallback = delayCallback; + } + + async init(): Promise<void> { + this.fileParameters = getWavFileParameters(this.filePath); + this.fileHandle = await fs.promises.open(this.filePath, 'r'); + + if (this._delayCallback && !this.fileParameters) { + console.warn('delay_callback not supported for encoding other than LINEAR_PCM'); + this._delayCallback = undefined; + } + } + + async next(): Promise<IteratorResult<AudioChunk>> { + if (!this.fileHandle || this.closed) { + return { done: true, value: undefined }; + } + + if (!this.fileParameters) { + const chunk = Buffer.alloc(this.chunkFrames); + const { bytesRead } = await this.fileHandle.read(chunk, 0, this.chunkFrames); + if (bytesRead === 0) { + await this.close(); + return { done: true, value: undefined }; + } + return { done: false, value: { audioContent: chunk.slice(0, bytesRead) } }; + } + + const bytesToRead = this.chunkFrames * this.fileParameters.sampwidth * this.fileParameters.nchannels; + const chunk = Buffer.alloc(bytesToRead); + const { bytesRead } = await this.fileHandle.read(chunk, 0, bytesToRead); + + if (bytesRead === 0) { + await this.close(); + return { done: true, value: undefined }; + } + + if (this._delayCallback) { + const offset = this.firstBuffer ? this.fileParameters.dataOffset : 0; + await this._delayCallback( + chunk.slice(offset), + (bytesRead - offset) / this.fileParameters.sampwidth / this.fileParameters.framerate + ); + this.firstBuffer = false; + } + + return { + done: false, + value: { audioContent: chunk.slice(0, bytesRead) } + }; + } + + async close(): Promise<void> { + if (this.fileHandle) { + await this.fileHandle.close(); + this.fileHandle = null; + } + this.closed = true; + } + + [Symbol.asyncIterator](): AsyncIterator<AudioChunk> { + return this; + } +} + +/** + * ASR Service for speech recognition + */ +export class ASRService extends BaseClient { + private readonly client: ASRServiceClient; + + constructor(config: { serverUrl: string; auth?: any }) { + super(config); + + const { RivaSpeechRecognitionClient } = getProtoClient('riva_asr'); + this.client = new RivaSpeechRecognitionClient( + config.serverUrl, + config.auth?.credentials || grpc.credentials.createInsecure() + ); + } + + private isContentSource(source: AudioSource): source is AudioContentSource { + return 'content' in source; + } + + private isAsyncIterable(source: AudioSource): source is AsyncIterable<AudioChunk> { + return Symbol.asyncIterator in source; + } + + private isIterable(source: AudioSource): source is Iterable<AudioChunk> { + return Symbol.iterator in source; + } + + /** + * Add word boosting to config + */ + public addWordBoosting( + config: RecognitionConfig | StreamingRecognitionConfig, + words: string[], + score: number + ): void { + const innerConfig = 'config' in config ? config.config : config; + if (words && words.length > 0) { + const context: SpeechContext = { + phrases: words, + boost: score + }; + innerConfig.speechContexts = innerConfig.speechContexts || []; + innerConfig.speechContexts.push(context); + } + } + + /** + * Add speaker diarization to config + */ + public addSpeakerDiarization( + config: RecognitionConfig, + enable: boolean, + maxSpeakers: number + ): void { + config.enableSpeakerDiarization = enable; + if (enable) { + config.diarizationConfig = { + enableSpeakerDiarization: true, + maxSpeakerCount: maxSpeakers + }; + } + } + + /** + * Add endpoint parameters to config + */ + public addEndpointParameters( + config: RecognitionConfig | StreamingRecognitionConfig, + endpointConfig: EndpointingConfig + ): void { + const innerConfig = 'config' in config ? config.config : config; + innerConfig.endpointingConfig = endpointConfig; + } + + /** + * Add audio file specs to config + */ + public addAudioFileSpecs( + config: RecognitionConfig | StreamingRecognitionConfig, + filePath: string + ): void { + const innerConfig = 'config' in config ? config.config : config; + const params = getWavFileParameters(filePath); + if (params) { + innerConfig.encoding = AudioEncoding.LINEAR_PCM; + innerConfig.sampleRateHertz = params.framerate; + innerConfig.audioChannelCount = params.nchannels; + } + } + + /** + * Add custom configuration to config + */ + public addCustomConfiguration( + config: RecognitionConfig | StreamingRecognitionConfig, + customConfig: string + ): void { + const innerConfig = 'config' in config ? config.config : config; + try { + const customConfigObj = JSON.parse(customConfig); + innerConfig.customConfiguration = customConfigObj; + } catch (error) { + console.warn('Failed to parse custom configuration:', error); + } + } + + /** + * Perform streaming recognition + */ + public async *streamingRecognize( + audioSource: AudioSource, + config: StreamingRecognitionConfig + ): AsyncGenerator<StreamingRecognizeResponse> { + const metadata = this.auth?.getCallMetadata(); + const stream = this.client.streamingRecognize(metadata); + + // Send config + stream.write({ streamingConfig: config }); + + // Send audio chunks + if (this.isContentSource(audioSource)) { + stream.write({ audioContent: audioSource.content }); + } else if (this.isAsyncIterable(audioSource)) { + for await (const chunk of audioSource) { + stream.write({ audioContent: chunk.audioContent }); + } + } else if (this.isIterable(audioSource)) { + for (const chunk of audioSource) { + stream.write({ audioContent: chunk.audioContent }); + } + } + + stream.end(); + + try { + for await (const response of stream) { + yield response; + } + } catch (err) { + const error = err as Error; + throw handleGrpcError(error); + } + } + + /** + * Perform offline recognition + */ + public async recognize(audio: Uint8Array, config: RecognitionConfig): Promise<RecognizeResponse> { + try { + return await this.client.recognize({ config, audio: { content: audio } }); + } catch (err) { + const error = err as Error; + throw handleGrpcError(error); + } + } + + /** + * List available ASR models + */ + public async listModels(): Promise<Array<{ name: string; languages: string[]; sampleRate: number; streaming: boolean }>> { + try { + const response = await this.client.listModels({}); + return response.models.map(model => ({ + name: model.name, + languages: model.languages, + sampleRate: model.sample_rate, + streaming: model.streaming_supported + })); + } catch (err) { + const error = err as Error; + throw handleGrpcError(error); + } + } +} diff --git a/riva-ts-client/src/client/asr/types.ts b/riva-ts-client/src/client/asr/types.ts new file mode
100644 index 00000000..fd637041 --- /dev/null +++ b/riva-ts-client/src/client/asr/types.ts @@ -0,0 +1,128 @@ +import * as grpc from '@grpc/grpc-js'; + +export enum AudioEncoding { + ENCODING_UNSPECIFIED = 0, + LINEAR_PCM = 1, + FLAC = 2, + MULAW = 3, + ALAW = 4 +} + +export interface SpeechContext { + phrases: string[]; + boost: number; +} + +export interface EndpointingConfig { + startHistory?: number; + startThreshold?: number; + stopHistory?: number; + stopHistoryEou?: number; + stopThreshold?: number; + stopThresholdEou?: number; +} + +export interface SpeakerDiarizationConfig { + enableSpeakerDiarization: boolean; + minSpeakerCount?: number; + maxSpeakerCount?: number; +} + +export interface RecognitionConfig { + encoding: AudioEncoding; + sampleRateHertz: number; + languageCode: string; + audioChannelCount?: number; + maxAlternatives?: number; + profanityFilter?: boolean; + enableAutomaticPunctuation?: boolean; + enableWordTimeOffsets?: boolean; + enableWordConfidence?: boolean; + enableRawTranscript?: boolean; + enableSpeakerDiarization?: boolean; + diarizationConfig?: SpeakerDiarizationConfig; + endpointingConfig?: EndpointingConfig; + speechContexts?: SpeechContext[]; + customConfiguration?: Record<string, string>; + model?: string; +} + +export interface StreamingRecognitionConfig { + config: RecognitionConfig; + interimResults?: boolean; + singleUtterance?: boolean; +} + +export interface WordInfo { + startTime: number; + endTime: number; + word: string; + confidence: number; + speakerTag?: string; +} + +export interface SpeechRecognitionAlternative { + transcript: string; + confidence: number; + words: WordInfo[]; +} + +export interface SpeechRecognitionResult { + alternatives: SpeechRecognitionAlternative[]; + channelTag: number; + languageCode: string; + isPartial?: boolean; +} + +export interface StreamingRecognizeResponse { + results: SpeechRecognitionResult[]; + speechEventType?: 'END_OF_SINGLE_UTTERANCE' | 'SPEECH_ACTIVITY_BEGIN' | 'SPEECH_ACTIVITY_END'; + timeOffset?: number; + audioContent?: Buffer; +} + +export interface RecognizeResponse { + results: SpeechRecognitionResult[]; +} + +export interface AudioChunk { + audioContent: Uint8Array; + timeOffset?: number; +} + +export interface WavFileParameters { + nframes: number; + framerate: number; + duration: number; + nchannels: number; + sampwidth: number; + dataOffset: number; +} + +export interface AudioSource { + content?: Uint8Array; + [Symbol.asyncIterator]?(): AsyncIterator<AudioChunk>; + [Symbol.iterator]?(): Iterator<AudioChunk>; +} + +export interface AudioContentSource { + content: Uint8Array; +} + +export interface ASRModel { + name: string; + languages: string[]; + sample_rate: number; + streaming_supported: boolean; +} + +export interface ListModelsResponse { + models: ASRModel[]; +} + +// Reconstructed request message shape for the duplex stream: the first message +// carries streamingConfig, subsequent messages carry audioContent chunks. +export interface StreamingRecognizeRequest { + streamingConfig?: StreamingRecognitionConfig; + audioContent?: Uint8Array; +} + +export interface ASRServiceClient { + config: RecognitionConfig; + streamingRecognize(metadata?: grpc.Metadata): grpc.ClientDuplexStream<StreamingRecognizeRequest, StreamingRecognizeResponse>; + recognize(request: { config: RecognitionConfig; audio: { content: Uint8Array } }): Promise<RecognizeResponse>; + listModels(request: {}): Promise<ListModelsResponse>; +} diff --git a/riva-ts-client/src/client/asr/utils.ts b/riva-ts-client/src/client/asr/utils.ts new file mode 100644 index 00000000..f8025d14 --- /dev/null +++ b/riva-ts-client/src/client/asr/utils.ts @@ -0,0 +1,123 @@ +import { Writable } from 'stream'; +import { StreamingRecognizeResponse, RecognizeResponse, AudioEncoding } from './types'; +import * as wav from 'node-wav'; +import * as fs from 'fs'; + +export type PrintMode = 'no' | 'time' | 'confidence'; + +export interface WavFileParameters {
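+ /** Decoded with node-wav; this helper always reports LINEAR_PCM together with the file's sample rate. */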
encoding: AudioEncoding; + sampleRate: number; +} + +/** + * Get WAV file parameters + */ +export async function getWavFileParameters(filePath: string): Promise<WavFileParameters> { + const buffer = fs.readFileSync(filePath); + const result = wav.decode(buffer); + + return { + encoding: AudioEncoding.LINEAR_PCM, + sampleRate: result.sampleRate + }; +} + +/** + * Print streaming recognition results + */ +export function printStreaming( + responses: AsyncIterable<StreamingRecognizeResponse>, + outputStreams: Writable[] = [process.stdout], + additionalInfo: PrintMode = 'no', + wordTimeOffsets = false, + showIntermediate = false, + _fileMode = 'w' +): void { + const write = (text: string) => { + for (const stream of outputStreams) { + stream.write(text + '\n'); + } + }; + + (async () => { + for await (const response of responses) { + if (!response.results.length) continue; + + for (const result of response.results) { + if (!showIntermediate && result.isPartial) continue; + + const prefix = result.isPartial ? '>>' : '##'; + let text = `${prefix} ${result.alternatives[0].transcript}`; + + if (additionalInfo === 'time') { + text = `${response.timeOffset?.toFixed(2)}s ${text}`; + } else if (additionalInfo === 'confidence') { + text = `${result.alternatives[0].confidence.toFixed(2)} ${text}`; + } + + write(text); + + if (wordTimeOffsets && result.alternatives[0].words.length > 0) { + const words = result.alternatives[0].words.map(word => { + const info = [word.word]; + if (word.startTime !== undefined) { + info.push(`${word.startTime.toFixed(2)}s`); + } + if (word.endTime !== undefined) { + info.push(`${word.endTime.toFixed(2)}s`); + } + if (word.confidence !== undefined) { + info.push(`${word.confidence.toFixed(2)}`); + } + if (word.speakerTag !== undefined) { + info.push(`speaker:${word.speakerTag}`); + } + return info.join(' '); + }); + write(`  ${words.join(' | ')}`); + } + } + } + })(); +} + +/** + * Print offline recognition results + */ +export function printOffline( + response: RecognizeResponse, + outputStreams: Writable[] = [process.stdout] +): void { + const write = (text: string) => { + for (const stream of outputStreams) { + stream.write(text + '\n'); + } + }; + + if (!response.results.length) return; + + for (const result of response.results) { + write(`## ${result.alternatives[0].transcript}`); + + if (result.alternatives[0].words.length > 0) { + const words = result.alternatives[0].words.map(word => { + const info = [word.word]; + if (word.startTime !== undefined) { + info.push(`${word.startTime.toFixed(2)}s`); + } + if (word.endTime !== undefined) { + info.push(`${word.endTime.toFixed(2)}s`); + } + if (word.confidence !== undefined) { + info.push(`${word.confidence.toFixed(2)}`); + } + if (word.speakerTag !== undefined) { + info.push(`speaker:${word.speakerTag}`); + } + return info.join(' '); + }); + write(`  ${words.join(' | ')}`); + } + } +} diff --git a/riva-ts-client/src/client/audio/io.ts b/riva-ts-client/src/client/audio/io.ts new file mode 100644 index 00000000..dc6bdab6 --- /dev/null +++ b/riva-ts-client/src/client/audio/io.ts @@ -0,0 +1,236 @@ +import { EventEmitter } from 'events'; +import { + AudioDeviceInfo, + AudioDeviceManager, + AudioStream, + AudioStreamCallbacks, + AudioStreamConfig, + MicrophoneStreamOptions, + SoundCallbackOptions +} from './types'; + +/** + * Manages audio devices and provides information about them + */ +export class AudioDeviceManagerImpl implements AudioDeviceManager { + private readonly audioContext: AudioContext; + + constructor() { + this.audioContext = new AudioContext(); + }
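+ /** + * Looks up a device by its index in the enumerateDevices() ordering. + * Channel counts and latencies are fixed placeholders, because the Web Audio + * API does not expose PyAudio-style device capabilities. + */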
async getDeviceInfo(deviceId: number): Promise { + const devices = await navigator.mediaDevices.enumerateDevices(); + const device = devices[deviceId]; + if (!device) { + throw new Error(`Device with ID ${deviceId} not found`); + } + + return { + index: deviceId, + name: device.label, + maxInputChannels: device.kind === 'audioinput' ? 1 : 0, + maxOutputChannels: device.kind === 'audiooutput' ? 2 : 0, + defaultSampleRate: this.audioContext.sampleRate, + defaultLowInputLatency: 0, + defaultLowOutputLatency: 0, + defaultHighInputLatency: 0, + defaultHighOutputLatency: 0 + }; + } + + async getDefaultInputDeviceInfo(): Promise { + const devices = await navigator.mediaDevices.enumerateDevices(); + const device = devices.find(d => d.kind === 'audioinput'); + if (!device) { + return null; + } + + return { + index: 0, + name: device.label, + maxInputChannels: 1, + maxOutputChannels: 0, + defaultSampleRate: this.audioContext.sampleRate, + defaultLowInputLatency: 0, + defaultLowOutputLatency: 0, + defaultHighInputLatency: 0, + defaultHighOutputLatency: 0 + }; + } + + async listOutputDevices(): Promise { + const devices = await navigator.mediaDevices.enumerateDevices(); + return devices + .filter(d => d.kind === 'audiooutput') + .map((device, index) => ({ + index, + name: device.label, + maxInputChannels: 0, + maxOutputChannels: 2, + defaultSampleRate: this.audioContext.sampleRate, + defaultLowInputLatency: 0, + defaultLowOutputLatency: 0, + defaultHighInputLatency: 0, + defaultHighOutputLatency: 0 + })); + } + + async listInputDevices(): Promise { + const devices = await navigator.mediaDevices.enumerateDevices(); + return devices + .filter(d => d.kind === 'audioinput') + .map((device, index) => ({ + index, + name: device.label, + maxInputChannels: 1, + maxOutputChannels: 0, + defaultSampleRate: this.audioContext.sampleRate, + defaultLowInputLatency: 0, + defaultLowOutputLatency: 0, + defaultHighInputLatency: 0, + defaultHighOutputLatency: 0 + })); + } +} + +/** + * Handles microphone input streaming + */ +export class MicrophoneStream extends EventEmitter implements AudioStream { + private readonly options: MicrophoneStreamOptions; + private mediaStream?: MediaStream; + private audioContext?: AudioContext; + private sourceNode?: MediaStreamAudioSourceNode; + private processorNode?: ScriptProcessorNode; + private active: boolean = false; + + constructor(options: MicrophoneStreamOptions) { + super(); + this.options = options; + } + + async start(): Promise { + if (this.active) { + return; + } + + try { + this.mediaStream = await navigator.mediaDevices.getUserMedia({ + audio: { + deviceId: this.options.device ? { exact: String(this.options.device) } : undefined, + sampleRate: this.options.rate, + channelCount: 1 + } + }); + + this.audioContext = new AudioContext({ + sampleRate: this.options.rate + }); + + this.sourceNode = this.audioContext.createMediaStreamSource(this.mediaStream); + this.processorNode = this.audioContext.createScriptProcessor( + this.options.chunk, + 1, + 1 + ); + + this.processorNode.onaudioprocess = (e) => { + const buffer = e.inputBuffer.getChannelData(0); + this.emit('data', Buffer.from(buffer.buffer)); + }; + + this.sourceNode.connect(this.processorNode); + this.processorNode.connect(this.audioContext.destination); + this.active = true; + } catch (error) { + this.emit('error', error instanceof Error ? 
error : new Error(String(error))); + } + } + + stop(): void { + if (!this.active) { + return; + } + + if (this.processorNode) { + this.processorNode.disconnect(); + this.processorNode = undefined; + } + + if (this.sourceNode) { + this.sourceNode.disconnect(); + this.sourceNode = undefined; + } + + if (this.mediaStream) { + this.mediaStream.getTracks().forEach(track => track.stop()); + this.mediaStream = undefined; + } + + if (this.audioContext) { + this.audioContext.close(); + this.audioContext = undefined; + } + + this.active = false; + this.emit('close'); + } + + pause(): void { + if (this.mediaStream) { + this.mediaStream.getTracks().forEach(track => track.enabled = false); + } + } + + resume(): void { + if (this.mediaStream) { + this.mediaStream.getTracks().forEach(track => track.enabled = true); + } + } + + isActive(): boolean { + return this.active; + } +} + +/** + * Handles audio output + */ +export class SoundCallback { + private readonly audioContext: AudioContext; + private readonly options: SoundCallbackOptions; + private opened: boolean = true; + + constructor(options: SoundCallbackOptions) { + this.options = options; + this.audioContext = new AudioContext({ + sampleRate: options.framerate, + latencyHint: 'interactive' + }); + } + + async write(audioData: Buffer): Promise { + if (!this.opened) { + throw new Error('Sound callback is closed'); + } + + const arrayBuffer = audioData.buffer.slice( + audioData.byteOffset, + audioData.byteOffset + audioData.byteLength + ); + + const audioBuffer = await this.audioContext.decodeAudioData(arrayBuffer); + const source = this.audioContext.createBufferSource(); + source.buffer = audioBuffer; + source.connect(this.audioContext.destination); + source.start(); + } + + close(): void { + if (this.opened) { + this.audioContext.close(); + this.opened = false; + } + } +} diff --git a/riva-ts-client/src/client/audio/transforms.ts b/riva-ts-client/src/client/audio/transforms.ts new file mode 100644 index 00000000..89b17b42 --- /dev/null +++ b/riva-ts-client/src/client/audio/transforms.ts @@ -0,0 +1,138 @@ +import { Transform, TransformCallback } from 'stream'; +import { AudioEncoding } from '../asr/types'; + +export interface AudioTransformOptions { + sourceSampleRate: number; + targetSampleRate: number; + sourceChannels: number; + targetChannels: number; + sourceEncoding: AudioEncoding; + targetEncoding: AudioEncoding; +} + +/** + * Base class for audio transformations + */ +export abstract class AudioTransform extends Transform { + protected readonly options: AudioTransformOptions; + + constructor(options: AudioTransformOptions) { + super(); + this.options = options; + } + + abstract _transform(chunk: Buffer, encoding: string, callback: TransformCallback): void; +} + +/** + * Transforms audio sample rate + */ +export class SampleRateTransform extends AudioTransform { + private remainder: Buffer = Buffer.alloc(0); + + _transform(chunk: Buffer, _encoding: string, callback: TransformCallback): void { + if (this.options.sourceSampleRate === this.options.targetSampleRate) { + this.push(chunk); + callback(); + return; + } + + // Combine with remainder from previous chunk + const buffer = Buffer.concat([this.remainder, chunk]); + const ratio = this.options.targetSampleRate / this.options.sourceSampleRate; + const bytesPerSample = 2; // Assuming 16-bit audio + const samplesPerChannel = Math.floor(buffer.length / (bytesPerSample * this.options.sourceChannels)); + const targetSamples = Math.floor(samplesPerChannel * ratio); + const targetSize = targetSamples 
* bytesPerSample * this.options.targetChannels; + + // Process complete samples + if (targetSize > 0) { + const resampledData = this.resample( + buffer.slice(0, samplesPerChannel * bytesPerSample * this.options.sourceChannels), + ratio + ); + this.push(resampledData); + + // Save remainder for next chunk + this.remainder = buffer.slice(samplesPerChannel * bytesPerSample * this.options.sourceChannels); + } else { + this.remainder = buffer; + } + + callback(); + } + + private resample(buffer: Buffer, ratio: number): Buffer { + // Simple linear interpolation - for production, use a proper resampling library + const result = Buffer.alloc(Math.floor(buffer.length * ratio)); + const bytesPerSample = 2; + const samplesPerChannel = buffer.length / (bytesPerSample * this.options.sourceChannels); + + for (let i = 0; i < Math.floor(samplesPerChannel * ratio); i++) { + const sourceIdx = Math.floor(i / ratio); + for (let channel = 0; channel < this.options.targetChannels; channel++) { + const value = buffer.readInt16LE(sourceIdx * bytesPerSample * this.options.sourceChannels + channel * bytesPerSample); + result.writeInt16LE(value, i * bytesPerSample * this.options.targetChannels + channel * bytesPerSample); + } + } + + return result; + } +} + +/** + * Transforms number of audio channels + */ +export class ChannelTransform extends AudioTransform { + _transform(chunk: Buffer, _encoding: string, callback: TransformCallback): void { + if (this.options.sourceChannels === this.options.targetChannels) { + this.push(chunk); + callback(); + return; + } + + const bytesPerSample = 2; // Assuming 16-bit audio + const samplesPerChannel = chunk.length / (bytesPerSample * this.options.sourceChannels); + const result = Buffer.alloc(samplesPerChannel * bytesPerSample * this.options.targetChannels); + + for (let i = 0; i < samplesPerChannel; i++) { + if (this.options.sourceChannels > this.options.targetChannels) { + // Downmix channels (average) + let sum = 0; + for (let ch = 0; ch < this.options.sourceChannels; ch++) { + sum += chunk.readInt16LE(i * bytesPerSample * this.options.sourceChannels + ch * bytesPerSample); + } + const avg = Math.round(sum / this.options.sourceChannels); + result.writeInt16LE(avg, i * bytesPerSample); + } else { + // Upmix channels (duplicate) + const value = chunk.readInt16LE(i * bytesPerSample * this.options.sourceChannels); + for (let ch = 0; ch < this.options.targetChannels; ch++) { + result.writeInt16LE(value, i * bytesPerSample * this.options.targetChannels + ch * bytesPerSample); + } + } + } + + this.push(result); + callback(); + } +} + +/** + * Creates a transform stream pipeline for audio processing + */ +export function createAudioTransformPipeline(options: AudioTransformOptions): Transform { + const transforms: Transform[] = []; + + // Add necessary transforms in the correct order + if (options.sourceSampleRate !== options.targetSampleRate) { + transforms.push(new SampleRateTransform(options)); + } + + if (options.sourceChannels !== options.targetChannels) { + transforms.push(new ChannelTransform(options)); + } + + // Chain transforms + return transforms.reduce((prev, curr) => prev.pipe(curr)); +} diff --git a/riva-ts-client/src/client/audio/types.ts b/riva-ts-client/src/client/audio/types.ts new file mode 100644 index 00000000..ae7a0273 --- /dev/null +++ b/riva-ts-client/src/client/audio/types.ts @@ -0,0 +1,58 @@ +export interface AudioDeviceInfo { + index: number; + name: string; + maxInputChannels: number; + maxOutputChannels: number; + defaultSampleRate: number; + 
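+ /** The latency fields mirror PyAudio's device info; the browser-backed manager in io.ts reports 0 for each of them. */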
defaultLowInputLatency: number; + defaultLowOutputLatency: number; + defaultHighInputLatency: number; + defaultHighOutputLatency: number; +} + +export interface MicrophoneStreamOptions { + rate: number; + chunk: number; + device?: number; +} + +export interface AudioStreamCallbacks { + onData?: (chunk: Buffer) => void; + onError?: (error: Error) => void; + onClose?: () => void; +} + +export interface SoundCallbackOptions { + outputDeviceIndex?: number; + sampwidth: number; + nchannels: number; + framerate: number; +} + +export interface AudioStreamConfig { + format: number; + channels: number; + rate: number; + framesPerBuffer: number; + inputDevice?: number; + outputDevice?: number; +} + +export interface AudioDeviceManager { + getDeviceInfo(deviceId: number): Promise; + getDefaultInputDeviceInfo(): Promise; + listOutputDevices(): Promise; + listInputDevices(): Promise; +} + +export interface AudioStream { + start(): void; + stop(): void; + pause(): void; + resume(): void; + isActive(): boolean; + on(event: 'data', callback: (chunk: Buffer) => void): void; + on(event: 'error', callback: (error: Error) => void): void; + on(event: 'close', callback: () => void): void; + off(event: string, callback: Function): void; +} diff --git a/riva-ts-client/src/client/audio/wav.ts b/riva-ts-client/src/client/audio/wav.ts new file mode 100644 index 00000000..c06734b3 --- /dev/null +++ b/riva-ts-client/src/client/audio/wav.ts @@ -0,0 +1,137 @@ +import { createReadStream } from 'fs'; +import { promisify } from 'util'; +import { Transform, TransformCallback } from 'stream'; +import { Readable } from 'stream'; + +export interface WavFileParameters { + nframes: number; + framerate: number; + duration: number; + nchannels: number; + sampwidth: number; + dataOffset: number; +} + +export class WavHeaderError extends Error { + constructor(message: string) { + super(message); + this.name = 'WavHeaderError'; + } +} + +/** + * Extracts WAV file parameters from a file + * @param filePath Path to WAV file + */ +export async function getWavFileParameters(filePath: string): Promise { + const buffer = Buffer.alloc(44); // Standard WAV header size + const stream = createReadStream(filePath, { start: 0, end: 43 }) as Readable; + + return new Promise((resolve, reject) => { + stream.on('error', reject); + stream.on('data', (chunk: Buffer) => { + try { + // Verify RIFF header + if (chunk.toString('ascii', 0, 4) !== 'RIFF') { + throw new WavHeaderError('Not a valid WAV file: missing RIFF header'); + } + + // Verify WAVE format + if (chunk.toString('ascii', 8, 12) !== 'WAVE') { + throw new WavHeaderError('Not a valid WAV file: missing WAVE format'); + } + + // Get format chunk + if (chunk.toString('ascii', 12, 16) !== 'fmt ') { + throw new WavHeaderError('Not a valid WAV file: missing fmt chunk'); + } + + const params: WavFileParameters = { + nchannels: chunk.readUInt16LE(22), + framerate: chunk.readUInt32LE(24), + sampwidth: chunk.readUInt16LE(34) / 8, + dataOffset: 44, // Standard WAV header size + nframes: chunk.readUInt32LE(40) / (chunk.readUInt16LE(22) * chunk.readUInt16LE(34) / 8), + duration: 0 // Will be calculated + }; + + params.duration = params.nframes / params.framerate; + resolve(params); + } catch (error) { + reject(error); + } finally { + stream.destroy(); + } + }); + }); +} + +/** + * Transform stream that splits audio data into chunks + */ +export class AudioChunkTransform extends Transform { + private readonly chunkSize: number; + private buffer: Buffer; + + constructor(chunkSize: number) { + super(); + 
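+ // chunkSize is a byte count (frames * channels * bytes per sample); see createAudioChunkIterator below.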
this.chunkSize = chunkSize; + this.buffer = Buffer.alloc(0); + } + + _transform(chunk: Buffer, _encoding: string, callback: TransformCallback): void { + // Append new data to buffer + this.buffer = Buffer.concat([this.buffer, chunk]); + + // Process complete chunks + while (this.buffer.length >= this.chunkSize) { + const chunkData = this.buffer.slice(0, this.chunkSize); + this.push(chunkData); + this.buffer = this.buffer.slice(this.chunkSize); + } + + callback(); + } + + _flush(callback: TransformCallback): void { + // Push remaining data if any + if (this.buffer.length > 0) { + this.push(this.buffer); + } + callback(); + } +} + +/** + * Creates an async iterator for audio chunks from a WAV file + */ +export async function* createAudioChunkIterator( + filePath: string, + chunkFrames: number, + params?: WavFileParameters +): AsyncGenerator<{ audioContent: Buffer; timeOffset: number }> { + // Get WAV parameters if not provided + const wavParams = params || await getWavFileParameters(filePath); + const chunkSize = chunkFrames * wavParams.nchannels * wavParams.sampwidth; + + // Create read stream starting after WAV header + const stream = createReadStream(filePath, { + start: wavParams.dataOffset, + highWaterMark: chunkSize + }); + + // Create transform stream for chunking + const chunker = new AudioChunkTransform(chunkSize); + stream.pipe(chunker); + + let timeOffset = 0; + const frameTime = 1 / wavParams.framerate; + + for await (const chunk of chunker) { + yield { + audioContent: chunk, + timeOffset + }; + timeOffset += chunkFrames * frameTime; + } +} diff --git a/riva-ts-client/src/client/auth/index.ts b/riva-ts-client/src/client/auth/index.ts new file mode 100644 index 00000000..9f679826 --- /dev/null +++ b/riva-ts-client/src/client/auth/index.ts @@ -0,0 +1,139 @@ +import * as grpc from '@grpc/grpc-js'; +import { readFileSync } from 'fs'; +import { resolve } from 'path'; + +export interface AuthOptions { + uri?: string; + useSsl?: boolean; + sslCert?: string; + credentials?: grpc.ChannelCredentials; + metadata?: Array<[string, string]>; + apiKey?: string; + channelOptions?: grpc.ChannelOptions; +} + +/** + * Creates a gRPC channel with the specified authentication settings + */ +function createChannel( + sslCert?: string, + useSsl: boolean = false, + uri: string = 'localhost:50051', + metadata?: Array<[string, string]>, + channelOptions: grpc.ChannelOptions = {} +): grpc.Channel { + const metadataCallback = (_params: any, callback: Function) => { + const grpcMetadata = new grpc.Metadata(); + if (metadata) { + for (const [key, value] of metadata) { + grpcMetadata.add(key, value); + } + } + callback(null, grpcMetadata); + }; + + if (sslCert || useSsl) { + let rootCertificates: Buffer | null = null; + if (sslCert) { + const certPath = resolve(sslCert); + rootCertificates = readFileSync(certPath); + } + + let creds = grpc.credentials.createSsl(rootCertificates); + if (metadata) { + const callCreds = grpc.credentials.createFromMetadataGenerator(metadataCallback); + creds = grpc.credentials.combineChannelCredentials(creds, callCreds); + } + return new grpc.Channel(uri, creds, channelOptions); + } + + return new grpc.Channel(uri, grpc.credentials.createInsecure(), channelOptions); +} + +export class Auth { + private readonly uri: string; + private readonly useSsl: boolean; + private readonly sslCert: string | undefined; + private readonly metadata: Array<[string, string]>; + private readonly channelOptions: grpc.ChannelOptions; + public readonly channel: grpc.Channel; + + constructor(options: 
AuthOptions); + constructor( + sslCert?: string, + useSsl?: boolean, + uri?: string, + metadataArgs?: string[][] + ); + constructor( + optionsOrSslCert?: AuthOptions | string, + useSsl: boolean = false, + uri: string = 'localhost:50051', + metadataArgs?: string[][] + ) { + if (typeof optionsOrSslCert === 'object') { + // AuthOptions constructor + const options = optionsOrSslCert; + this.uri = options.uri || 'localhost:50051'; + this.useSsl = options.useSsl || false; + this.sslCert = options.sslCert; + this.channelOptions = options.channelOptions || {}; + + // Combine provided metadata with API key if present + this.metadata = [...(options.metadata || [])]; + if (options.apiKey) { + this.metadata.push(['api-key', options.apiKey]); + } + } else { + // Python-style constructor + this.uri = uri; + this.useSsl = useSsl; + this.sslCert = optionsOrSslCert; + this.channelOptions = {}; + this.metadata = []; + + if (metadataArgs) { + for (const meta of metadataArgs) { + if (meta.length !== 2) { + throw new Error(`Metadata should have 2 parameters in "key" "value" pair. Received ${meta.length} parameters.`); + } + this.metadata.push([meta[0], meta[1]]); + } + } + } + + this.channel = createChannel( + this.sslCert, + this.useSsl, + this.uri, + this.metadata, + this.channelOptions + ); + } + + /** + * Gets metadata for gRPC calls + */ + getCallMetadata(): grpc.Metadata { + const metadata = new grpc.Metadata(); + for (const [key, value] of this.metadata) { + metadata.add(key, value); + } + return metadata; + } + + /** + * Alias for getCallMetadata to maintain Python compatibility + */ + getAuthMetadata(): Array<[string, string]> { + return this.metadata; + } + + /** + * Creates a gRPC channel with the current settings + * @deprecated Use the channel property instead + */ + createChannel(): grpc.Channel { + return this.channel; + } +} diff --git a/riva-ts-client/src/client/base.ts b/riva-ts-client/src/client/base.ts new file mode 100644 index 00000000..8d92c271 --- /dev/null +++ b/riva-ts-client/src/client/base.ts @@ -0,0 +1,60 @@ +import * as grpc from '@grpc/grpc-js'; +import { Auth, AuthOptions } from './auth/index'; +import { RivaConfig } from './types'; +import { Logger, createLogger, format, transports } from 'winston'; + +export abstract class BaseClient { + protected readonly auth: Auth; + protected readonly channel: grpc.Channel; + protected readonly logger: Logger; + + constructor(config: RivaConfig) { + const authOptions: AuthOptions = { + uri: config.serverUrl, + useSsl: config.auth?.ssl || false, + apiKey: config.auth?.apiKey, + metadata: config.auth?.metadata, + credentials: config.auth?.credentials, + sslCert: config.auth?.sslCert + }; + + this.auth = new Auth(authOptions); + this.channel = this.auth.createChannel(); + + // Set up logging with winston + this.logger = createLogger({ + level: config.logging?.level || 'info', + format: config.logging?.format === 'json' ? 
format.json() : format.simple(), + transports: [new transports.Console()] + }); + } + + /** + * Closes the gRPC channel + */ + close(): Promise { + return new Promise((resolve, reject) => { + try { + this.channel.close(); + resolve(); + } catch (err) { + reject(err); + } + }); + } + + /** + * Gets call metadata for gRPC requests + */ + protected getCallMetadata(): grpc.Metadata { + return this.auth.getCallMetadata(); + } + + /** + * Creates gRPC deadline from timeout in milliseconds + */ + protected createDeadline(timeoutMs?: number): Date | undefined { + if (!timeoutMs) return undefined; + return new Date(Date.now() + timeoutMs); + } +} diff --git a/riva-ts-client/src/client/errors.ts b/riva-ts-client/src/client/errors.ts new file mode 100644 index 00000000..b2ec731c --- /dev/null +++ b/riva-ts-client/src/client/errors.ts @@ -0,0 +1,54 @@ +import * as grpc from '@grpc/grpc-js'; + +export class RivaError extends Error { + constructor( + message: string, + public readonly code?: grpc.status, + public readonly details?: string + ) { + super(message); + this.name = 'RivaError'; + } + + static fromGrpcError(error: Error & { code?: grpc.status; details?: string }): RivaError { + return new RivaError( + error.message, + error.code, + error.details + ); + } +} + +export class AuthenticationError extends RivaError { + constructor(message: string, details?: string) { + super(message, grpc.status.UNAUTHENTICATED, details); + this.name = 'AuthenticationError'; + } +} + +export class ConnectionError extends RivaError { + constructor(message: string, details?: string) { + super(message, grpc.status.UNAVAILABLE, details); + this.name = 'ConnectionError'; + } +} + +export class InvalidArgumentError extends RivaError { + constructor(message: string, details?: string) { + super(message, grpc.status.INVALID_ARGUMENT, details); + this.name = 'InvalidArgumentError'; + } +} + +export function handleGrpcError(error: Error & { code?: grpc.status }): never { + switch (error.code) { + case grpc.status.UNAUTHENTICATED: + throw new AuthenticationError(error.message); + case grpc.status.UNAVAILABLE: + throw new ConnectionError(error.message); + case grpc.status.INVALID_ARGUMENT: + throw new InvalidArgumentError(error.message); + default: + throw RivaError.fromGrpcError(error); + } +} diff --git a/riva-ts-client/src/client/index.ts b/riva-ts-client/src/client/index.ts new file mode 100644 index 00000000..3fce163a --- /dev/null +++ b/riva-ts-client/src/client/index.ts @@ -0,0 +1,6 @@ +export * from './auth'; +export * from './base'; +export * from './asr'; +export * from './nlp'; +export * from './tts'; +export { NeuralMachineTranslationService } from './nmt'; diff --git a/riva-ts-client/src/client/nlp/index.ts b/riva-ts-client/src/client/nlp/index.ts new file mode 100644 index 00000000..9575f328 --- /dev/null +++ b/riva-ts-client/src/client/nlp/index.ts @@ -0,0 +1,211 @@ +import * as grpc from '@grpc/grpc-js'; +import { BaseClient } from '../base'; +import { handleGrpcError } from '../errors'; +import { getProtoClient } from '../utils/proto'; +import { + ClassifyRequest, + TokenClassifyRequest, + AnalyzeEntitiesRequest, + AnalyzeIntentRequest, + AnalyzeIntentResponse, + TransformTextRequest, + ClassifyResponse, + TokenClassifyResponse, + TransformTextResponse, + NaturalQueryRequest, + NaturalQueryResponse, + AnalyzeEntitiesResponse, + RivaNLPServiceClientImpl +} from '../../proto/riva_nlp'; + +/** + * Provides text classification, token classification, text transformation, + * intent recognition, punctuation and 
capitalization restoration, and + * question answering services. + */ +export class NLPService extends BaseClient { + private readonly client: RivaNLPServiceClientImpl; + + constructor(config: { serverUrl: string; auth?: any }) { + super(config); + const { RivaNLPServiceClient } = getProtoClient('riva_nlp'); + this.client = new RivaNLPServiceClient( + config.serverUrl, + config.auth?.credentials || grpc.credentials.createInsecure() + ); + } + + /** + * Classifies text provided in inputStrings. For example, this method can be used for + * intent classification. + */ + async classifyText( + inputStrings: string | string[], + modelName: string, + languageCode: string = 'en-US', + future: boolean = false + ): Promise<ClassifyResponse> { + try { + const texts = Array.isArray(inputStrings) ? inputStrings : [inputStrings]; + const request: ClassifyRequest = { + text: texts, + model: { + modelName, + languageCode + } + }; + + return await this.client.Classify(request); + } catch (error) { + if (error instanceof Error) { + throw handleGrpcError(error); + } + throw new Error('Unknown error occurred'); + } + } + + /** + * Classifies tokens in the text provided in inputStrings. For example, this method can be used for + * named entity recognition. + */ + async classifyTokens( + inputStrings: string | string[], + modelName: string, + languageCode: string = 'en-US', + future: boolean = false + ): Promise<TokenClassifyResponse> { + try { + const texts = Array.isArray(inputStrings) ? inputStrings : [inputStrings]; + const request: TokenClassifyRequest = { + text: texts, + model: { + modelName, + languageCode + } + }; + + return await this.client.TokenClassify(request); + } catch (error) { + if (error instanceof Error) { + throw handleGrpcError(error); + } + throw new Error('Unknown error occurred'); + } + } + + /** + * Transforms text provided in inputString. For example, this method can be used for + * text normalization. + */ + async transformText( + inputString: string, + modelName: string, + future: boolean = false + ): Promise<TransformTextResponse> { + try { + const request: TransformTextRequest = { + text: inputString, + model: modelName + }; + + return await this.client.TransformText(request); + } catch (error) { + if (error instanceof Error) { + throw handleGrpcError(error); + } + throw new Error('Unknown error occurred'); + } + } + + /** + * Restores punctuation and capitalization in the text provided in inputString. + */ + async punctuateText( + inputString: string, + modelName: string, + future: boolean = false + ): Promise<TransformTextResponse> { + try { + const request: TransformTextRequest = { + text: inputString, + model: modelName + }; + + return await this.client.TransformText(request); + } catch (error) { + if (error instanceof Error) { + throw handleGrpcError(error); + } + throw new Error('Unknown error occurred'); + } + } + + /** + * Analyzes entities in the text provided in inputString. + */ + async analyzeEntities( + inputString: string, + modelName: string, + languageCode: string = 'en-US', + future: boolean = false + ): Promise<AnalyzeEntitiesResponse> { + try { + const request: AnalyzeEntitiesRequest = { + text: inputString + }; + + return await this.client.AnalyzeEntities(request); + } catch (error) { + if (error instanceof Error) { + throw handleGrpcError(error); + } + throw new Error('Unknown error occurred'); + } + } + + /** + * Analyzes intent in the text provided in inputString.
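+ * @example + * // Illustrative only; 'my_intent_model' is a placeholder, not a shipped Riva model name. + * const intent = await nlp.analyzeIntent('Book a flight to Paris', 'my_intent_model');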
+ */ + async analyzeIntent( + inputString: string, + modelName: string, + languageCode: string = 'en-US', + future: boolean = false + ): Promise<AnalyzeIntentResponse> { + try { + const request: AnalyzeIntentRequest = { + text: inputString + }; + + return await this.client.AnalyzeIntent(request); + } catch (error) { + if (error instanceof Error) { + throw handleGrpcError(error); + } + throw new Error('Unknown error occurred'); + } + } + + /** + * Performs natural language query using the provided query and context. + */ + async naturalQuery( + query: string, + context: string, + future: boolean = false + ): Promise<NaturalQueryResponse> { + try { + const request: NaturalQueryRequest = { + query, + context + }; + + return await this.client.NaturalQuery(request); + } catch (error) { + if (error instanceof Error) { + throw handleGrpcError(error); + } + throw new Error('Unknown error occurred'); + } + } +} diff --git a/riva-ts-client/src/client/nlp/types.ts b/riva-ts-client/src/client/nlp/types.ts new file mode 100644 index 00000000..d2ced05b --- /dev/null +++ b/riva-ts-client/src/client/nlp/types.ts @@ -0,0 +1,176 @@ +import * as grpc from '@grpc/grpc-js'; +import { + ClassifyResponse, + TokenClassifyResponse, + TransformTextResponse, + AnalyzeEntitiesResponse, + AnalyzeIntentResponse, + NaturalQueryResponse +} from '../../proto/riva_nlp'; + +/** + * Interface for the NLP service client that matches the Python implementation + */ +export interface NLPServiceClient extends grpc.Client { + /** + * Classifies text provided in inputStrings. For example, this method can be used for + * intent classification. + */ + classifyText( + inputStrings: string | string[], + modelName: string, + languageCode?: string, + future?: boolean + ): Promise<ClassifyResponse>; + + /** + * Classifies tokens in texts in inputStrings. Can be used for slot classification or NER. + */ + classifyTokens( + inputStrings: string | string[], + modelName: string, + languageCode?: string, + future?: boolean + ): Promise<TokenClassifyResponse>; + + /** + * The behavior of the function is defined entirely by the underlying model and may be used for + * tasks such as translation, adding punctuation, or augmenting the input directly. + */ + transformText( + inputStrings: string | string[], + modelName: string, + languageCode?: string, + future?: boolean + ): Promise<TransformTextResponse>; + + /** + * Takes text with little or no punctuation and returns the same text with corrected punctuation and + * capitalization. + */ + punctuateText( + inputStrings: string | string[], + languageCode?: string, + future?: boolean + ): Promise<TransformTextResponse>; + + /** + * Accepts an input string and returns all named entities within the text, as well as a category and likelihood. + */ + analyzeEntities( + inputString: string, + languageCode?: string, + future?: boolean + ): Promise<AnalyzeEntitiesResponse>; + + /** + * Accepts an input string and returns the most likely intent as well as slots relevant to that intent. + */ + analyzeIntent( + inputString: string, + options?: any, + future?: boolean + ): Promise<AnalyzeIntentResponse>; + + /** + * A search function that enables querying one or more documents or contexts with a query that is written in + * natural language.
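+ * @example + * // Sketch only; assumes an initialized client implementing this interface. + * const answer = await client.naturalQuery('Who founded NVIDIA?', documentText, 1);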
*/ + naturalQuery( + query: string, + context: string, + topN?: number, + future?: boolean + ): Promise<NaturalQueryResponse>; +} + +// Utility functions for result extraction +export function extractAllTextClassesAndConfidences( + response: any +): [string[][], number[][]] { + const textClasses: string[][] = []; + const confidences: number[][] = []; + + for (const result of response.results) { + textClasses.push(result.labels.map((lbl: any) => lbl.className)); + confidences.push(result.labels.map((lbl: any) => lbl.score)); + } + + return [textClasses, confidences]; +} + +export function extractMostProbableTextClassAndConfidence( + response: any +): [string[], number[]] { + const [intents, confidences] = extractAllTextClassesAndConfidences(response); + return [intents.map(x => x[0]), confidences.map(x => x[0])]; +} + +export function extractAllTokenClassificationPredictions( + response: any +): [string[][], string[][][], number[][][], number[][][], number[][][]] { + const tokens: string[][] = []; + const tokenClasses: string[][][] = []; + const confidences: number[][][] = []; + const starts: number[][][] = []; + const ends: number[][][] = []; + + for (const batchResult of response.results) { + const elemTokens: string[] = []; + const elemTokenClasses: string[][] = []; + const elemConfidences: number[][] = []; + const elemStarts: number[][] = []; + const elemEnds: number[][] = []; + + for (const result of batchResult.results) { + elemTokens.push(result.token); + elemTokenClasses.push(result.label.map((lbl: any) => lbl.className)); + elemConfidences.push(result.label.map((lbl: any) => lbl.score)); + elemStarts.push(result.span.map((span: any) => span.start)); + elemEnds.push(result.span.map((span: any) => span.end)); + } + + tokens.push(elemTokens); + tokenClasses.push(elemTokenClasses); + confidences.push(elemConfidences); + starts.push(elemStarts); + ends.push(elemEnds); + } + + return [tokens, tokenClasses, confidences, starts, ends]; +} + +export function extractMostProbableTokenClassificationPredictions( + response: any +): [string[][], string[][], number[][], number[][], number[][]] { + const [tokens, tokenClasses, confidences, starts, ends] = extractAllTokenClassificationPredictions(response); + return [ + tokens, + tokenClasses.map(x => x.map(xx => xx[0])), + confidences.map(x => x.map(xx => xx[0])), + starts.map(x => x.map(xx => xx[0])), + ends.map(x => x.map(xx => xx[0])) + ]; +} + +export function extractAllTransformedTexts(response: any): string[] { + return response.text; +} + +export function extractMostProbableTransformedText(response: any): string { + return response.text[0]; +} + +export function prepareTransformTextRequest( + inputStrings: string | string[], + modelName: string, + languageCode: string = 'en-US' +): any { + const texts = Array.isArray(inputStrings) ?
inputStrings : [inputStrings]; + return { + text: texts, + model: { + modelName, + languageCode + } + }; +} diff --git a/riva-ts-client/src/client/nmt/index.ts b/riva-ts-client/src/client/nmt/index.ts new file mode 100644 index 00000000..0a08376d --- /dev/null +++ b/riva-ts-client/src/client/nmt/index.ts @@ -0,0 +1,133 @@ +import * as grpc from '@grpc/grpc-js'; +import { BaseClient } from '../base'; +import { handleGrpcError } from '../errors'; +import { getProtoClient } from '../utils/proto'; +import { + StreamingS2SRequest, + StreamingS2SResponse, + StreamingS2TRequest, + StreamingS2TResponse, + TranslateRequest, + TranslateResponse, + AvailableLanguageRequest, + AvailableLanguageResponse, + NMTServiceClient, + StreamingS2SConfig, + StreamingS2TConfig, + ClientConfig +} from './types'; + +/** + * Generator for streaming speech-to-speech translation requests + */ +function* streaming_s2s_request_generator( + audioChunks: Iterable, + streamingConfig: StreamingS2SConfig +): Generator { + yield { config: streamingConfig }; + for (const chunk of audioChunks) { + yield { audioContent: chunk }; + } +} + +/** + * Generator for streaming speech-to-text translation requests + */ +function* streaming_s2t_request_generator( + audioChunks: Iterable, + streamingConfig: StreamingS2TConfig +): Generator { + yield { config: streamingConfig }; + for (const chunk of audioChunks) { + yield { audioContent: chunk }; + } +} + +/** + * Neural Machine Translation Service for text and speech translation + */ +export class NeuralMachineTranslationService extends BaseClient { + private readonly stub: NMTServiceClient; + + constructor(config: ClientConfig) { + super(config); + const { RivaSpeechTranslationClient } = getProtoClient('riva_services'); + this.stub = new RivaSpeechTranslationClient( + config.serverUrl, + config.auth?.credentials || grpc.credentials.createInsecure() + ) as NMTServiceClient; + } + + /** + * Generates speech to speech translation responses for fragments of speech audio + */ + async *streaming_s2s_response_generator( + audioChunks: Iterable, + streamingConfig: StreamingS2SConfig + ): AsyncGenerator { + try { + const generator = streaming_s2s_request_generator(audioChunks, streamingConfig); + const stream = this.stub.streamingTranslateSpeechToSpeech(generator, this.getCallMetadata()); + + for await (const response of stream) { + yield response; + } + } catch (err) { + const error = err as Error; + throw handleGrpcError(error); + } + } + + /** + * Generates speech to text translation responses for fragments of speech audio + */ + async *streaming_s2t_response_generator( + audioChunks: Iterable, + streamingConfig: StreamingS2TConfig + ): AsyncGenerator { + try { + const generator = streaming_s2t_request_generator(audioChunks, streamingConfig); + const stream = this.stub.streamingTranslateSpeechToText(generator, this.getCallMetadata()); + + for await (const response of stream) { + yield response; + } + } catch (err) { + const error = err as Error; + throw handleGrpcError(error); + } + } + + /** + * Translates text from one language to another + */ + async translate(request: TranslateRequest): Promise { + try { + return await this.stub.translateText(request, this.getCallMetadata()); + } catch (err) { + const error = err as Error; + throw handleGrpcError(error); + } + } + + /** + * Gets supported language pairs for a model + */ + async get_supported_language_pairs(model: string): Promise { + try { + return await this.stub.listSupportedLanguagePairs({ model }, this.getCallMetadata()); + } catch 
(err) { + const error = err as Error; + throw handleGrpcError(error); + } + } +} + +export type { + TranslateRequest, + TranslateResponse, + StreamingS2SConfig, + StreamingS2TConfig, + AvailableLanguageRequest, + AvailableLanguageResponse +}; diff --git a/riva-ts-client/src/client/nmt/types.ts b/riva-ts-client/src/client/nmt/types.ts new file mode 100644 index 00000000..fda7cc12 --- /dev/null +++ b/riva-ts-client/src/client/nmt/types.ts @@ -0,0 +1,329 @@ +import * as grpc from '@grpc/grpc-js'; +import { AudioEncoding } from '../asr/types'; +import { RivaConfig } from '../types'; + +/** + * Configuration options for the client. + */ +export interface ClientConfig extends RivaConfig { + /** + * The URL of the server to connect to. + */ + serverUrl: string; + + /** + * Optional function to create a gRPC client for testing + */ + createClient?: (url: string, credentials: grpc.ChannelCredentials) => grpc.Client; + // Add other configuration options as needed +} + +/** + * Configuration options for speech recognition. + */ +export interface RecognitionConfig { + /** + * Whether to enable automatic punctuation. + */ + enableAutomaticPunctuation?: boolean; + /** + * The audio encoding of the input audio. + */ + audioEncoding?: AudioEncoding; + /** + * The sample rate of the input audio in Hz. + */ + sampleRateHertz?: number; + /** + * The language code of the input audio. + */ + languageCode?: string; +} + +/** + * Configuration options for streaming speech recognition. + */ +export interface StreamingRecognitionConfig { + /** + * The recognition configuration. + */ + config: RecognitionConfig; + /** + * Whether to return interim results. + */ + interimResults: boolean; +} + +/** + * Configuration options for text translation. + */ +export interface TranslationConfig { + /** + * The source language code. + */ + sourceLanguageCode: string; + /** + * The target language code. + */ + targetLanguageCode: string; + /** + * The model to use for translation. + */ + model?: string; + /** + * The phrases to not translate. + */ + doNotTranslatePhrases?: string[]; +} + +/** + * Configuration options for speech synthesis. + */ +export interface SynthesizeSpeechConfig { + /** + * The sample rate of the output audio in Hz. + */ + sampleRateHz: number; + /** + * The voice to use for synthesis. + */ + voiceName?: string; + /** + * The language code of the output audio. + */ + languageCode?: string; +} + +/** + * Configuration options for streaming speech-to-speech translation. + */ +export interface StreamingS2SConfig { + /** + * The ASR configuration. + */ + asrConfig: StreamingRecognitionConfig; + /** + * The translation configuration. + */ + translationConfig: TranslationConfig; + /** + * The TTS configuration. + */ + ttsConfig: SynthesizeSpeechConfig; +} + +/** + * Configuration options for streaming speech-to-text translation. + */ +export interface StreamingS2TConfig { + /** + * The ASR configuration. + */ + asrConfig: StreamingRecognitionConfig; + /** + * The translation configuration. + */ + translationConfig: TranslationConfig; +} + +/** + * Request message for streaming speech-to-speech translation. + */ +export interface StreamingS2SRequest { + /** + * The configuration for the request. + */ + config?: StreamingS2SConfig; + /** + * The audio content for the request. + */ + audioContent?: Uint8Array; +} + +/** + * Request message for streaming speech-to-text translation. + */ +export interface StreamingS2TRequest { + /** + * The configuration for the request. 
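+ * Sent only in the first request of a stream; later requests carry audioContent (see streaming_s2t_request_generator).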
+ */ + config?: StreamingS2TConfig; + /** + * The audio content for the request. + */ + audioContent?: Uint8Array; +} + +/** + * Response message for streaming speech-to-speech translation. + */ +export interface StreamingS2SResponse { + /** + * The result of the translation. + */ + result: { + /** + * The transcript of the input audio. + */ + transcript: string; + /** + * The translation of the input audio. + */ + translation: string; + /** + * The synthesized audio content. + */ + audioContent: Uint8Array; + /** + * Whether the result is partial. + */ + isPartial: boolean; + }; +} + +/** + * Response message for streaming speech-to-text translation. + */ +export interface StreamingS2TResponse { + /** + * The result of the translation. + */ + result: { + /** + * The transcript of the input audio. + */ + transcript: string; + /** + * The translation of the input audio. + */ + translation: string; + /** + * Whether the result is partial. + */ + isPartial: boolean; + }; +} + +/** + * Request message for text translation. + */ +export interface TranslateRequest { + /** + * The text to translate. + */ + text: string; + /** + * The source language. + */ + sourceLanguage: string; + /** + * The target language. + */ + targetLanguage: string; + /** + * The model to use for translation. + */ + model?: string; + /** + * The phrases to not translate. + */ + doNotTranslatePhrases?: string[]; +} + +/** + * Response message for text translation. + */ +export interface TranslateResponse { + /** + * The translations. + */ + translations: Array<{ + /** + * The translated text. + */ + text: string; + /** + * The confidence score of the translation. + */ + score: number; + }>; + /** + * The translated text. + */ + text: string; + /** + * The confidence score of the translation. + */ + score: number; +} + +/** + * A language pair. + */ +export interface LanguagePair { + /** + * The source language code. + */ + sourceLanguageCode: string; + /** + * The target language code. + */ + targetLanguageCode: string; +} + +/** + * Request message for listing supported language pairs. + */ +export interface AvailableLanguageRequest { + /** + * The model to use for listing language pairs. + */ + model: string; +} + +/** + * Response message for listing supported language pairs. + */ +export interface AvailableLanguageResponse { + /** + * The supported language pairs. + */ + supportedLanguagePairs: LanguagePair[]; +} + +/** + * The NMT service client. + */ +export interface NMTServiceClient { + /** + * Translates text. + */ + translateText( + request: TranslateRequest, + metadata?: grpc.Metadata + ): Promise; + + /** + * Lists supported language pairs. + */ + listSupportedLanguagePairs( + request: AvailableLanguageRequest, + metadata?: grpc.Metadata + ): Promise; + + /** + * Streams speech-to-speech translation. + */ + streamingTranslateSpeechToSpeech( + request: Generator | StreamingS2SRequest, + metadata?: grpc.Metadata + ): grpc.ClientReadableStream; + + /** + * Streams speech-to-text translation. 
+ */ + streamingTranslateSpeechToText( + request: Generator | StreamingS2TRequest, + metadata?: grpc.Metadata + ): grpc.ClientReadableStream; +} diff --git a/riva-ts-client/src/client/package_info.ts b/riva-ts-client/src/client/package_info.ts new file mode 100644 index 00000000..796eefc6 --- /dev/null +++ b/riva-ts-client/src/client/package_info.ts @@ -0,0 +1,36 @@ +export const VERSION = { + MAJOR: 2, + MINOR: 18, + PATCH: 0, + PRE_RELEASE: 'rc0' +} as const; + +export const PACKAGE_INFO = { + shortversion: `${VERSION.MAJOR}.${VERSION.MINOR}.${VERSION.PATCH}`, + version: `${VERSION.MAJOR}.${VERSION.MINOR}.${VERSION.PATCH}${VERSION.PRE_RELEASE}`, + packageName: 'nvidia-riva-client', + contactNames: 'Anton Peganov', + contactEmails: 'apeganov@nvidia.com', + homepage: 'https://github.com/nvidia-riva/python-clients', + repositoryUrl: 'https://github.com/nvidia-riva/python-clients', + downloadUrl: 'https://github.com/nvidia-riva/python-clients/releases', + description: 'TypeScript implementation of the Riva Client API', + license: 'MIT', + keywords: [ + 'deep learning', + 'machine learning', + 'gpu', + 'NLP', + 'ASR', + 'TTS', + 'NMT', + 'nvidia', + 'speech', + 'language', + 'Riva', + 'client' + ], + rivaVersion: '2.18.0', + rivaRelease: '24.12', + rivaModelsVersion: '2.18.0' +} as const; diff --git a/riva-ts-client/src/client/tts/index.ts b/riva-ts-client/src/client/tts/index.ts new file mode 100644 index 00000000..ef3d78f4 --- /dev/null +++ b/riva-ts-client/src/client/tts/index.ts @@ -0,0 +1,216 @@ +import * as grpc from '@grpc/grpc-js'; +import * as fs from 'fs'; +import { WaveFile } from 'wavefile'; +import { BaseClient } from '../base'; +import { RivaError, handleGrpcError } from '../errors'; +import { AudioEncoding } from '../asr/types'; +import { RivaConfig } from '../types'; +import { getProtoClient } from '../utils/proto'; +import { + SynthesizeSpeechRequest, + SynthesizeSpeechResponse, + ZeroShotData, + RivaSpeechSynthesisStub, + WaveFile as WaveFileType, + RivaSynthesisConfigResponse +} from './types'; + +function convertSamplesToBuffer(samples: Float64Array | Float64Array[]): Buffer { + if (Array.isArray(samples)) { + // Multi-channel audio + const flatSamples = new Float64Array(samples.reduce((acc: number[], channel) => { + channel.forEach(sample => acc.push(sample)); + return acc; + }, [])); + return Buffer.from(flatSamples.buffer); + } else { + // Single channel audio + return Buffer.from(samples.buffer); + } +} + +/** + * Add custom dictionary to synthesis request config + */ +function addCustomDictionaryToConfig( + req: SynthesizeSpeechRequest, + customDictionary?: Record +): void { + if (customDictionary) { + const resultList = Object.entries(customDictionary).map(([key, value]) => `${key} ${value}`); + if (resultList.length > 0) { + req.customDictionary = resultList.join(','); + } + } +} + +/** + * A class for synthesizing speech from text. Provides synthesize which returns entire audio for a text + * and synthesizeOnline which returns audio in small chunks as it is becoming available. + */ +export class SpeechSynthesisService extends BaseClient { + private readonly stub: RivaSpeechSynthesisStub; + + /** + * Initializes an instance of the class. 
+   * @param config Configuration for the service
+   */
+  constructor(config: RivaConfig) {
+    super(config);
+    const { RivaSpeechSynthesisStub } = getProtoClient('riva_services');
+    this.stub = new RivaSpeechSynthesisStub(
+      config.serverUrl,
+      config.auth?.credentials || grpc.credentials.createInsecure()
+    );
+  }
+
+  /**
+   * Gets the available voices and their configurations.
+   * @returns Promise with the synthesis configuration response
+   */
+  async getRivaSynthesisConfig(): Promise<RivaSynthesisConfigResponse> {
+    try {
+      return await this.stub.GetRivaSynthesisConfig({}, this.getCallMetadata());
+    } catch (error: unknown) {
+      if (error instanceof Error) {
+        throw error;
+      }
+      throw new RivaError('Unknown error occurred');
+    }
+  }
+
+  /**
+   * Synthesizes the entire audio for a text input.
+   * @param text An input text.
+   * @param voiceName A name of the voice, e.g. "English-US-Female-1". If omitted, the server selects the first available model.
+   * @param languageCode A language to use.
+   * @param encoding An output audio encoding, e.g. AudioEncoding.LINEAR_PCM.
+   * @param sampleRateHz Number of frames per second in output audio.
+   * @param audioPromptFile An audio prompt file location for the zero shot model.
+   * @param audioPromptEncoding Encoding of the audio prompt file, e.g. AudioEncoding.LINEAR_PCM.
+   * @param quality Number of times the decoder is run. Higher values improve quality but take longer.
+   * @param future Whether to return an async result instead of the usual response. Not currently used by this implementation.
+   * @param customDictionary Dictionary mapping graphemes to the corresponding phonemes.
+   */
+  async synthesize(
+    text: string,
+    voiceName?: string,
+    languageCode: string = 'en-US',
+    encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
+    sampleRateHz: number = 44100,
+    audioPromptFile?: string,
+    audioPromptEncoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
+    quality: number = 20,
+    future: boolean = false,
+    customDictionary?: Record<string, string>
+  ): Promise<SynthesizeSpeechResponse> {
+    const req: SynthesizeSpeechRequest = {
+      text,
+      languageCode,
+      sampleRateHz,
+      encoding
+    };
+
+    if (voiceName) {
+      req.voiceName = voiceName;
+    }
+
+    if (audioPromptFile) {
+      const wavFile = new WaveFile(fs.readFileSync(audioPromptFile)) as unknown as WaveFileType;
+      const samples = wavFile.getSamples();
+      if (!samples || (Array.isArray(samples) && !samples.length)) {
+        throw new RivaError('Invalid WAV file: no samples found');
+      }
+      if (!wavFile.fmt?.sampleRate) {
+        throw new RivaError('Invalid WAV file: no sample rate found');
+      }
+
+      const zeroShotData: ZeroShotData = {
+        audioPrompt: convertSamplesToBuffer(samples),
+        encoding: audioPromptEncoding,
+        sampleRateHz: wavFile.fmt.sampleRate,
+        quality
+      };
+      req.zeroShotData = zeroShotData;
+    }
+
+    addCustomDictionaryToConfig(req, customDictionary);
+
+    try {
+      return await this.stub.Synthesize(req, this.getCallMetadata());
+    } catch (error) {
+      if (error instanceof Error) {
+        throw handleGrpcError(error);
+      }
+      throw new RivaError('Unknown error during synthesis');
+    }
+  }
+
+  /**
+   * Synthesizes and yields output audio chunks for a text as the chunks become available.
+   * @param text An input text.
+   * @param voiceName A name of the voice, e.g. "English-US-Female-1". If omitted, the server selects the first available model.
+   * @param languageCode A language to use.
+   * @param encoding An output audio encoding, e.g. AudioEncoding.LINEAR_PCM.
+   * @param sampleRateHz Number of frames per second in output audio.
+   * @param audioPromptFile An audio prompt file location for the zero shot model.
+   * @param audioPromptEncoding Encoding of the audio prompt file, e.g. AudioEncoding.LINEAR_PCM.
+   * @param quality Number of times the decoder is run. Higher values improve quality but take longer.
+   * @param customDictionary Dictionary mapping graphemes to the corresponding phonemes.
+   */
+  async *synthesizeOnline(
+    text: string,
+    voiceName?: string,
+    languageCode: string = 'en-US',
+    encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
+    sampleRateHz: number = 44100,
+    audioPromptFile?: string,
+    audioPromptEncoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
+    quality: number = 20,
+    customDictionary?: Record<string, string>
+  ): AsyncGenerator<SynthesizeSpeechResponse> {
+    const req: SynthesizeSpeechRequest = {
+      text,
+      languageCode,
+      sampleRateHz,
+      encoding
+    };
+
+    if (voiceName) {
+      req.voiceName = voiceName;
+    }
+
+    if (audioPromptFile) {
+      const wavFile = new WaveFile(fs.readFileSync(audioPromptFile)) as unknown as WaveFileType;
+      const samples = wavFile.getSamples();
+      if (!samples || (Array.isArray(samples) && !samples.length)) {
+        throw new RivaError('Invalid WAV file: no samples found');
+      }
+      if (!wavFile.fmt?.sampleRate) {
+        throw new RivaError('Invalid WAV file: no sample rate found');
+      }
+
+      const zeroShotData: ZeroShotData = {
+        audioPrompt: convertSamplesToBuffer(samples),
+        encoding: audioPromptEncoding,
+        sampleRateHz: wavFile.fmt.sampleRate,
+        quality
+      };
+      req.zeroShotData = zeroShotData;
+    }
+
+    addCustomDictionaryToConfig(req, customDictionary);
+
+    try {
+      const stream = this.stub.SynthesizeOnline(req, this.getCallMetadata());
+      for await (const response of stream) {
+        yield response;
+      }
+    } catch (error) {
+      if (error instanceof Error) {
+        throw handleGrpcError(error);
+      }
+      throw new RivaError('Unknown error during streaming synthesis');
+    }
+  }
+}
diff --git a/riva-ts-client/src/client/tts/types.ts b/riva-ts-client/src/client/tts/types.ts
new file mode 100644
index 00000000..43d00504
--- /dev/null
+++ b/riva-ts-client/src/client/tts/types.ts
@@ -0,0 +1,86 @@
+import * as grpc from '@grpc/grpc-js';
+import { AudioEncoding } from '../asr/types';
+
+export interface WaveFmt {
+  sampleRate: number;
+}
+
+export interface WaveFile {
+  fmt: WaveFmt;
+  getSamples(): Float64Array | Float64Array[];
+}
+
+export interface ZeroShotData {
+  audioPrompt: Uint8Array;
+  encoding: AudioEncoding;
+  sampleRateHz: number;
+  quality: number;
+}
+
+export interface SynthesizeSpeechRequest {
+  text: string;
+  languageCode: string;
+  sampleRateHz: number;
+  encoding: AudioEncoding;
+  voiceName?: string;
+  zeroShotData?: ZeroShotData;
+  customDictionary?: string;
+}
+
+export interface SynthesizeSpeechResponse {
+  audio: Uint8Array;
+  audioConfig: {
+    encoding: AudioEncoding;
+    sampleRateHz: number;
+  };
+}
+
+export interface RivaSynthesisConfigRequest {
+  // Empty request
+}
+
+export interface VoiceParameters {
+  languageCode: string;
+  voiceName: string;
+  subvoices: string;
+}
+
+export interface ModelConfig {
+  parameters: VoiceParameters;
+}
+
+export interface RivaSynthesisConfigResponse {
+  modelConfig: ModelConfig[];
+}
+
+export interface RivaSpeechSynthesisStub extends grpc.Client {
+  /**
+   * Synthesizes speech synchronously
+   */
+  Synthesize(
+    request: SynthesizeSpeechRequest,
+    metadata: grpc.Metadata,
+    callback: (error: grpc.ServiceError | null, response: SynthesizeSpeechResponse) => void
+  ): grpc.ClientUnaryCall;
+
+  Synthesize(
+    request: SynthesizeSpeechRequest,
+    metadata: grpc.Metadata
+  ): Promise<SynthesizeSpeechResponse>;
+
+  /**
+   * Synthesizes speech in streaming mode, returning chunks as they become available
+   */
+  SynthesizeOnline(
+    request: SynthesizeSpeechRequest,
+    metadata: grpc.Metadata
+  ): grpc.ClientReadableStream<SynthesizeSpeechResponse>;
+
+  /**
+   * Gets available voice configurations
+   */
+  GetRivaSynthesisConfig(
+    request: RivaSynthesisConfigRequest,
+    metadata: grpc.Metadata
+  ): Promise<RivaSynthesisConfigResponse>;
+}
diff --git a/riva-ts-client/src/client/types.ts b/riva-ts-client/src/client/types.ts
new file mode 100644
index 00000000..a2da2fe9
--- /dev/null
+++ b/riva-ts-client/src/client/types.ts
@@ -0,0 +1,55 @@
+import * as grpc from '@grpc/grpc-js';
+
+export interface AuthConfig {
+  /**
+   * Whether to use SSL/TLS for the connection
+   */
+  ssl?: boolean;
+
+  /**
+   * Path to SSL certificate file
+   */
+  sslCert?: string;
+
+  /**
+   * API key for authentication
+   */
+  apiKey?: string;
+
+  /**
+   * SSL/TLS credentials
+   */
+  credentials?: grpc.ChannelCredentials;
+
+  /**
+   * Additional metadata to send with each request
+   */
+  metadata?: [string, string][];
+}
+
+export interface RivaConfig {
+  /**
+   * Riva server URL (e.g., 'localhost:50051')
+   */
+  serverUrl: string;
+
+  /**
+   * Authentication configuration
+   */
+  auth?: AuthConfig;
+
+  /**
+   * Logging configuration
+   */
+  logging?: {
+    /**
+     * Log level (default: 'info')
+     */
+    level?: string;
+
+    /**
+     * Log format ('simple' or 'json', default: 'simple')
+     */
+    format?: 'simple' | 'json';
+  };
+}
diff --git a/riva-ts-client/src/client/types/index.ts b/riva-ts-client/src/client/types/index.ts
new file mode 100644
index 00000000..195ef75f
--- /dev/null
+++ b/riva-ts-client/src/client/types/index.ts
@@ -0,0 +1,17 @@
+// Basic types for the Riva client
+export interface RivaConfig {
+  serverUrl: string;
+  auth?: {
+    ssl?: boolean;
+    apiKey?: string;
+  };
+  logging?: {
+    level: 'debug' | 'info' | 'warn' | 'error';
+  };
+}
+
+export interface AudioConfig {
+  sampleRateHz: number;
+  encoding: 'LINEAR16' | 'FLAC' | 'MULAW' | 'ALAW';
+  languageCode?: string;
+}
diff --git a/riva-ts-client/src/client/utils/proto.ts b/riva-ts-client/src/client/utils/proto.ts
new file mode 100644
index 00000000..d4381f35
--- /dev/null
+++ b/riva-ts-client/src/client/utils/proto.ts
@@ -0,0 +1,10 @@
+/**
+ * Utility to handle proto imports across different environments
+ */
+export const getProtoClient = (name: string) => {
+  try {
+    return require(`../../src/proto/${name}`);
+  } catch (e) {
+    return require(`../../proto/${name}`);
+  }
+};
diff --git a/riva-ts-client/src/index.ts b/riva-ts-client/src/index.ts
new file mode 100644
index 00000000..0eaa5fb0
--- /dev/null
+++ b/riva-ts-client/src/index.ts
@@ -0,0 +1,5 @@
+// Main entry point for the Riva TypeScript client
+export * from './client/types';
+
+// Placeholder for initial setup
+console.log('Riva TypeScript Client initialized');
diff --git a/riva-ts-client/src/proto/riva_asr.ts b/riva-ts-client/src/proto/riva_asr.ts
new file mode 100644
index 00000000..7ab70d74
--- /dev/null
+++ b/riva-ts-client/src/proto/riva_asr.ts
@@ -0,0 +1,718 @@
+// Code generated by protoc-gen-ts_proto. DO NOT EDIT.
+// versions: +// protoc-gen-ts_proto v1.181.2 +// protoc v5.29.3 +// source: riva_asr.proto + +/* eslint-disable */ +import _m0 from "protobufjs/minimal"; +import { Observable } from "rxjs"; +import { map } from "rxjs/operators"; +import { AudioConfig } from "./riva_services"; + +export const protobufPackage = "nvidia.riva"; + +export interface RecognizeRequest { + config: AudioConfig | undefined; + audio: Uint8Array; + model: string; +} + +export interface RecognizeResponse { + results: RecognizeResponse_Result[]; +} + +export interface RecognizeResponse_Result { + transcript: string; + confidence: number; + words: WordInfo[]; +} + +export interface StreamingRecognizeRequest { + config?: AudioConfig | undefined; + audioContent?: Uint8Array | undefined; +} + +export interface StreamingRecognizeResponse { + results: StreamingRecognizeResponse_Result[]; +} + +export interface StreamingRecognizeResponse_Result { + transcript: string; + confidence: number; + isFinal: boolean; + words: WordInfo[]; +} + +export interface WordInfo { + word: string; + startTime: number; + endTime: number; + confidence: number; +} + +function createBaseRecognizeRequest(): RecognizeRequest { + return { config: undefined, audio: new Uint8Array(0), model: "" }; +} + +export const RecognizeRequest = { + encode(message: RecognizeRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.config !== undefined) { + AudioConfig.encode(message.config, writer.uint32(10).fork()).ldelim(); + } + if (message.audio.length !== 0) { + writer.uint32(18).bytes(message.audio); + } + if (message.model !== "") { + writer.uint32(26).string(message.model); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): RecognizeRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseRecognizeRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.config = AudioConfig.decode(reader, reader.uint32()); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.audio = reader.bytes(); + continue; + case 3: + if (tag !== 26) { + break; + } + + message.model = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): RecognizeRequest { + return { + config: isSet(object.config) ? AudioConfig.fromJSON(object.config) : undefined, + audio: isSet(object.audio) ? bytesFromBase64(object.audio) : new Uint8Array(0), + model: isSet(object.model) ? globalThis.String(object.model) : "", + }; + }, + + toJSON(message: RecognizeRequest): unknown { + const obj: any = {}; + if (message.config !== undefined) { + obj.config = AudioConfig.toJSON(message.config); + } + if (message.audio.length !== 0) { + obj.audio = base64FromBytes(message.audio); + } + if (message.model !== "") { + obj.model = message.model; + } + return obj; + }, + + create, I>>(base?: I): RecognizeRequest { + return RecognizeRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): RecognizeRequest { + const message = createBaseRecognizeRequest(); + message.config = (object.config !== undefined && object.config !== null) + ? AudioConfig.fromPartial(object.config) + : undefined; + message.audio = object.audio ?? new Uint8Array(0); + message.model = object.model ?? 
""; + return message; + }, +}; + +function createBaseRecognizeResponse(): RecognizeResponse { + return { results: [] }; +} + +export const RecognizeResponse = { + encode(message: RecognizeResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.results) { + RecognizeResponse_Result.encode(v!, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): RecognizeResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseRecognizeResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.results.push(RecognizeResponse_Result.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): RecognizeResponse { + return { + results: globalThis.Array.isArray(object?.results) + ? object.results.map((e: any) => RecognizeResponse_Result.fromJSON(e)) + : [], + }; + }, + + toJSON(message: RecognizeResponse): unknown { + const obj: any = {}; + if (message.results?.length) { + obj.results = message.results.map((e) => RecognizeResponse_Result.toJSON(e)); + } + return obj; + }, + + create, I>>(base?: I): RecognizeResponse { + return RecognizeResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): RecognizeResponse { + const message = createBaseRecognizeResponse(); + message.results = object.results?.map((e) => RecognizeResponse_Result.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseRecognizeResponse_Result(): RecognizeResponse_Result { + return { transcript: "", confidence: 0, words: [] }; +} + +export const RecognizeResponse_Result = { + encode(message: RecognizeResponse_Result, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.transcript !== "") { + writer.uint32(10).string(message.transcript); + } + if (message.confidence !== 0) { + writer.uint32(21).float(message.confidence); + } + for (const v of message.words) { + WordInfo.encode(v!, writer.uint32(26).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): RecognizeResponse_Result { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseRecognizeResponse_Result(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.transcript = reader.string(); + continue; + case 2: + if (tag !== 21) { + break; + } + + message.confidence = reader.float(); + continue; + case 3: + if (tag !== 26) { + break; + } + + message.words.push(WordInfo.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): RecognizeResponse_Result { + return { + transcript: isSet(object.transcript) ? globalThis.String(object.transcript) : "", + confidence: isSet(object.confidence) ? globalThis.Number(object.confidence) : 0, + words: globalThis.Array.isArray(object?.words) ? 
object.words.map((e: any) => WordInfo.fromJSON(e)) : [], + }; + }, + + toJSON(message: RecognizeResponse_Result): unknown { + const obj: any = {}; + if (message.transcript !== "") { + obj.transcript = message.transcript; + } + if (message.confidence !== 0) { + obj.confidence = message.confidence; + } + if (message.words?.length) { + obj.words = message.words.map((e) => WordInfo.toJSON(e)); + } + return obj; + }, + + create, I>>(base?: I): RecognizeResponse_Result { + return RecognizeResponse_Result.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): RecognizeResponse_Result { + const message = createBaseRecognizeResponse_Result(); + message.transcript = object.transcript ?? ""; + message.confidence = object.confidence ?? 0; + message.words = object.words?.map((e) => WordInfo.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseStreamingRecognizeRequest(): StreamingRecognizeRequest { + return { config: undefined, audioContent: undefined }; +} + +export const StreamingRecognizeRequest = { + encode(message: StreamingRecognizeRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.config !== undefined) { + AudioConfig.encode(message.config, writer.uint32(10).fork()).ldelim(); + } + if (message.audioContent !== undefined) { + writer.uint32(18).bytes(message.audioContent); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): StreamingRecognizeRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseStreamingRecognizeRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.config = AudioConfig.decode(reader, reader.uint32()); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.audioContent = reader.bytes(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): StreamingRecognizeRequest { + return { + config: isSet(object.config) ? AudioConfig.fromJSON(object.config) : undefined, + audioContent: isSet(object.audioContent) ? bytesFromBase64(object.audioContent) : undefined, + }; + }, + + toJSON(message: StreamingRecognizeRequest): unknown { + const obj: any = {}; + if (message.config !== undefined) { + obj.config = AudioConfig.toJSON(message.config); + } + if (message.audioContent !== undefined) { + obj.audioContent = base64FromBytes(message.audioContent); + } + return obj; + }, + + create, I>>(base?: I): StreamingRecognizeRequest { + return StreamingRecognizeRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): StreamingRecognizeRequest { + const message = createBaseStreamingRecognizeRequest(); + message.config = (object.config !== undefined && object.config !== null) + ? AudioConfig.fromPartial(object.config) + : undefined; + message.audioContent = object.audioContent ?? 
undefined; + return message; + }, +}; + +function createBaseStreamingRecognizeResponse(): StreamingRecognizeResponse { + return { results: [] }; +} + +export const StreamingRecognizeResponse = { + encode(message: StreamingRecognizeResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.results) { + StreamingRecognizeResponse_Result.encode(v!, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): StreamingRecognizeResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseStreamingRecognizeResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.results.push(StreamingRecognizeResponse_Result.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): StreamingRecognizeResponse { + return { + results: globalThis.Array.isArray(object?.results) + ? object.results.map((e: any) => StreamingRecognizeResponse_Result.fromJSON(e)) + : [], + }; + }, + + toJSON(message: StreamingRecognizeResponse): unknown { + const obj: any = {}; + if (message.results?.length) { + obj.results = message.results.map((e) => StreamingRecognizeResponse_Result.toJSON(e)); + } + return obj; + }, + + create, I>>(base?: I): StreamingRecognizeResponse { + return StreamingRecognizeResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): StreamingRecognizeResponse { + const message = createBaseStreamingRecognizeResponse(); + message.results = object.results?.map((e) => StreamingRecognizeResponse_Result.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseStreamingRecognizeResponse_Result(): StreamingRecognizeResponse_Result { + return { transcript: "", confidence: 0, isFinal: false, words: [] }; +} + +export const StreamingRecognizeResponse_Result = { + encode(message: StreamingRecognizeResponse_Result, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.transcript !== "") { + writer.uint32(10).string(message.transcript); + } + if (message.confidence !== 0) { + writer.uint32(21).float(message.confidence); + } + if (message.isFinal !== false) { + writer.uint32(24).bool(message.isFinal); + } + for (const v of message.words) { + WordInfo.encode(v!, writer.uint32(34).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): StreamingRecognizeResponse_Result { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseStreamingRecognizeResponse_Result(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.transcript = reader.string(); + continue; + case 2: + if (tag !== 21) { + break; + } + + message.confidence = reader.float(); + continue; + case 3: + if (tag !== 24) { + break; + } + + message.isFinal = reader.bool(); + continue; + case 4: + if (tag !== 34) { + break; + } + + message.words.push(WordInfo.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): StreamingRecognizeResponse_Result { + return { + transcript: isSet(object.transcript) ? globalThis.String(object.transcript) : "", + confidence: isSet(object.confidence) ? globalThis.Number(object.confidence) : 0, + isFinal: isSet(object.isFinal) ? globalThis.Boolean(object.isFinal) : false, + words: globalThis.Array.isArray(object?.words) ? object.words.map((e: any) => WordInfo.fromJSON(e)) : [], + }; + }, + + toJSON(message: StreamingRecognizeResponse_Result): unknown { + const obj: any = {}; + if (message.transcript !== "") { + obj.transcript = message.transcript; + } + if (message.confidence !== 0) { + obj.confidence = message.confidence; + } + if (message.isFinal !== false) { + obj.isFinal = message.isFinal; + } + if (message.words?.length) { + obj.words = message.words.map((e) => WordInfo.toJSON(e)); + } + return obj; + }, + + create, I>>( + base?: I, + ): StreamingRecognizeResponse_Result { + return StreamingRecognizeResponse_Result.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>( + object: I, + ): StreamingRecognizeResponse_Result { + const message = createBaseStreamingRecognizeResponse_Result(); + message.transcript = object.transcript ?? ""; + message.confidence = object.confidence ?? 0; + message.isFinal = object.isFinal ?? false; + message.words = object.words?.map((e) => WordInfo.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseWordInfo(): WordInfo { + return { word: "", startTime: 0, endTime: 0, confidence: 0 }; +} + +export const WordInfo = { + encode(message: WordInfo, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.word !== "") { + writer.uint32(10).string(message.word); + } + if (message.startTime !== 0) { + writer.uint32(21).float(message.startTime); + } + if (message.endTime !== 0) { + writer.uint32(29).float(message.endTime); + } + if (message.confidence !== 0) { + writer.uint32(37).float(message.confidence); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): WordInfo { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? 
reader.len : reader.pos + length;
+    const message = createBaseWordInfo();
+    while (reader.pos < end) {
+      const tag = reader.uint32();
+      switch (tag >>> 3) {
+        case 1:
+          if (tag !== 10) {
+            break;
+          }
+
+          message.word = reader.string();
+          continue;
+        case 2:
+          if (tag !== 21) {
+            break;
+          }
+
+          message.startTime = reader.float();
+          continue;
+        case 3:
+          if (tag !== 29) {
+            break;
+          }
+
+          message.endTime = reader.float();
+          continue;
+        case 4:
+          if (tag !== 37) {
+            break;
+          }
+
+          message.confidence = reader.float();
+          continue;
+      }
+      if ((tag & 7) === 4 || tag === 0) {
+        break;
+      }
+      reader.skipType(tag & 7);
+    }
+    return message;
+  },
+
+  fromJSON(object: any): WordInfo {
+    return {
+      word: isSet(object.word) ? globalThis.String(object.word) : "",
+      startTime: isSet(object.startTime) ? globalThis.Number(object.startTime) : 0,
+      endTime: isSet(object.endTime) ? globalThis.Number(object.endTime) : 0,
+      confidence: isSet(object.confidence) ? globalThis.Number(object.confidence) : 0,
+    };
+  },
+
+  toJSON(message: WordInfo): unknown {
+    const obj: any = {};
+    if (message.word !== "") {
+      obj.word = message.word;
+    }
+    if (message.startTime !== 0) {
+      obj.startTime = message.startTime;
+    }
+    if (message.endTime !== 0) {
+      obj.endTime = message.endTime;
+    }
+    if (message.confidence !== 0) {
+      obj.confidence = message.confidence;
+    }
+    return obj;
+  },
+
+  create<I extends Exact<DeepPartial<WordInfo>, I>>(base?: I): WordInfo {
+    return WordInfo.fromPartial(base ?? ({} as any));
+  },
+  fromPartial<I extends Exact<DeepPartial<WordInfo>, I>>(object: I): WordInfo {
+    const message = createBaseWordInfo();
+    message.word = object.word ?? "";
+    message.startTime = object.startTime ?? 0;
+    message.endTime = object.endTime ?? 0;
+    message.confidence = object.confidence ?? 0;
+    return message;
+  },
+};
+
+export interface RivaSpeechRecognition {
+  Recognize(request: RecognizeRequest): Promise<RecognizeResponse>;
+  StreamingRecognize(request: Observable<StreamingRecognizeRequest>): Observable<StreamingRecognizeResponse>;
+}
+
+export const RivaSpeechRecognitionServiceName = "nvidia.riva.RivaSpeechRecognition";
+export class RivaSpeechRecognitionClientImpl implements RivaSpeechRecognition {
+  private readonly rpc: Rpc;
+  private readonly service: string;
+  constructor(rpc: Rpc, opts?: { service?: string }) {
+    this.service = opts?.service || RivaSpeechRecognitionServiceName;
+    this.rpc = rpc;
+    this.Recognize = this.Recognize.bind(this);
+    this.StreamingRecognize = this.StreamingRecognize.bind(this);
+  }
+  Recognize(request: RecognizeRequest): Promise<RecognizeResponse> {
+    const data = RecognizeRequest.encode(request).finish();
+    const promise = this.rpc.request(this.service, "Recognize", data);
+    return promise.then((data) => RecognizeResponse.decode(_m0.Reader.create(data)));
+  }
+
+  StreamingRecognize(request: Observable<StreamingRecognizeRequest>): Observable<StreamingRecognizeResponse> {
+    const data = request.pipe(map((request) => StreamingRecognizeRequest.encode(request).finish()));
+    const result = this.rpc.bidirectionalStreamingRequest(this.service, "StreamingRecognize", data);
+    return result.pipe(map((data) => StreamingRecognizeResponse.decode(_m0.Reader.create(data))));
+  }
+}
+
+interface Rpc {
+  request(service: string, method: string, data: Uint8Array): Promise<Uint8Array>;
+  clientStreamingRequest(service: string, method: string, data: Observable<Uint8Array>): Promise<Uint8Array>;
+  serverStreamingRequest(service: string, method: string, data: Uint8Array): Observable<Uint8Array>;
+  bidirectionalStreamingRequest(service: string, method: string, data: Observable<Uint8Array>): Observable<Uint8Array>;
+}
+
+function bytesFromBase64(b64: string): Uint8Array {
+  if ((globalThis as any).Buffer) {
+    return Uint8Array.from(globalThis.Buffer.from(b64, "base64"));
+  } else {
+    const bin = globalThis.atob(b64);
+    const arr = new Uint8Array(bin.length);
+    for (let i = 0; i < bin.length; ++i) {
+      arr[i] = bin.charCodeAt(i);
+    }
+    return arr;
+  }
+}
+
+function base64FromBytes(arr: Uint8Array): string {
+  if ((globalThis as any).Buffer) {
+    return globalThis.Buffer.from(arr).toString("base64");
+  } else {
+    const bin: string[] = [];
+    arr.forEach((byte) => {
+      bin.push(globalThis.String.fromCharCode(byte));
+    });
+    return globalThis.btoa(bin.join(""));
+  }
+}
+
+type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined;
+
+export type DeepPartial<T> = T extends Builtin ? T
+  : T extends globalThis.Array<infer U> ? globalThis.Array<DeepPartial<U>>
+  : T extends ReadonlyArray<infer U> ? ReadonlyArray<DeepPartial<U>>
+  : T extends {} ? { [K in keyof T]?: DeepPartial<T[K]> }
+  : Partial<T>;
+
+type KeysOfUnion<T> = T extends T ? keyof T : never;
+export type Exact<P, I extends P> = P extends Builtin ? P
+  : P & { [K in keyof P]: Exact<P[K], I[K]> } & { [K in Exclude<keyof I, KeysOfUnion<P>>]: never };
+
+function isSet(value: any): boolean {
+  return value !== null && value !== undefined;
+}
diff --git a/riva-ts-client/src/proto/riva_nlp.ts b/riva-ts-client/src/proto/riva_nlp.ts
new file mode 100644
index 00000000..818e33b1
--- /dev/null
+++ b/riva-ts-client/src/proto/riva_nlp.ts
@@ -0,0 +1,1617 @@
+// Code generated by protoc-gen-ts_proto. DO NOT EDIT.
+// versions:
+// protoc-gen-ts_proto v1.181.2
+// protoc v5.29.3
+// source: riva_nlp.proto
+
+/* eslint-disable */
+import _m0 from "protobufjs/minimal";
+
+export const protobufPackage = "nvidia.riva";
+
+export interface ClassifyRequest {
+  text: string[];
+  model: ClassifyRequest_Model | undefined;
+}
+
+export interface ClassifyRequest_Model {
+  modelName: string;
+  languageCode: string;
+}
+
+export interface ClassifyResponse {
+  results: ClassifyResponse_Result[];
+}
+
+export interface ClassifyResponse_Result {
+  label: string;
+  score: number;
+}
+
+export interface TokenClassifyRequest {
+  text: string[];
+  model: TokenClassifyRequest_Model | undefined;
+}
+
+export interface TokenClassifyRequest_Model {
+  modelName: string;
+  languageCode: string;
+}
+
+export interface TokenClassifyResponse {
+  results: TokenClassifyResponse_Result[];
+}
+
+export interface TokenClassifyResponse_Token {
+  text: string;
+  label: string;
+  score: number;
+  start: number;
+  end: number;
+}
+
+export interface TokenClassifyResponse_Result {
+  tokens: TokenClassifyResponse_Token[];
+}
+
+export interface AnalyzeEntitiesRequest {
+  text: string;
+}
+
+export interface AnalyzeEntitiesResponse {
+  entities: AnalyzeEntitiesResponse_Entity[];
+}
+
+export interface AnalyzeEntitiesResponse_Entity {
+  text: string;
+  type: string;
+  score: number;
+  start: number;
+  end: number;
+}
+
+export interface AnalyzeIntentRequest {
+  text: string;
+}
+
+export interface AnalyzeIntentResponse {
+  intent: string;
+  confidence: number;
+  slots: AnalyzeIntentResponse_Slot[];
+}
+
+export interface AnalyzeIntentResponse_Slot {
+  text: string;
+  type: string;
+  score: number;
+}
+
+export interface TransformTextRequest {
+  text: string;
+  model: string;
+}
+
+export interface TransformTextResponse {
+  text: string;
+}
+
+export interface NaturalQueryRequest {
+  query: string;
+  context: string;
+}
+
+export interface NaturalQueryResponse {
+  response: string;
+  confidence: number;
+}
+
+function createBaseClassifyRequest(): ClassifyRequest {
+  return { text: [], model: undefined };
+}
+
+export const ClassifyRequest = {
+  encode(message: ClassifyRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
+    for (const v of message.text) {
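+      // repeated string field 1: tag 10 = (field number 1 << 3) | wire type 2 (LEN)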
writer.uint32(10).string(v!); + } + if (message.model !== undefined) { + ClassifyRequest_Model.encode(message.model, writer.uint32(18).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): ClassifyRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseClassifyRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text.push(reader.string()); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.model = ClassifyRequest_Model.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): ClassifyRequest { + return { + text: globalThis.Array.isArray(object?.text) ? object.text.map((e: any) => globalThis.String(e)) : [], + model: isSet(object.model) ? ClassifyRequest_Model.fromJSON(object.model) : undefined, + }; + }, + + toJSON(message: ClassifyRequest): unknown { + const obj: any = {}; + if (message.text?.length) { + obj.text = message.text; + } + if (message.model !== undefined) { + obj.model = ClassifyRequest_Model.toJSON(message.model); + } + return obj; + }, + + create, I>>(base?: I): ClassifyRequest { + return ClassifyRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): ClassifyRequest { + const message = createBaseClassifyRequest(); + message.text = object.text?.map((e) => e) || []; + message.model = (object.model !== undefined && object.model !== null) + ? ClassifyRequest_Model.fromPartial(object.model) + : undefined; + return message; + }, +}; + +function createBaseClassifyRequest_Model(): ClassifyRequest_Model { + return { modelName: "", languageCode: "" }; +} + +export const ClassifyRequest_Model = { + encode(message: ClassifyRequest_Model, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.modelName !== "") { + writer.uint32(10).string(message.modelName); + } + if (message.languageCode !== "") { + writer.uint32(18).string(message.languageCode); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): ClassifyRequest_Model { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseClassifyRequest_Model(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.modelName = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.languageCode = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): ClassifyRequest_Model { + return { + modelName: isSet(object.modelName) ? globalThis.String(object.modelName) : "", + languageCode: isSet(object.languageCode) ? globalThis.String(object.languageCode) : "", + }; + }, + + toJSON(message: ClassifyRequest_Model): unknown { + const obj: any = {}; + if (message.modelName !== "") { + obj.modelName = message.modelName; + } + if (message.languageCode !== "") { + obj.languageCode = message.languageCode; + } + return obj; + }, + + create, I>>(base?: I): ClassifyRequest_Model { + return ClassifyRequest_Model.fromPartial(base ?? 
({} as any)); + }, + fromPartial, I>>(object: I): ClassifyRequest_Model { + const message = createBaseClassifyRequest_Model(); + message.modelName = object.modelName ?? ""; + message.languageCode = object.languageCode ?? ""; + return message; + }, +}; + +function createBaseClassifyResponse(): ClassifyResponse { + return { results: [] }; +} + +export const ClassifyResponse = { + encode(message: ClassifyResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.results) { + ClassifyResponse_Result.encode(v!, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): ClassifyResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseClassifyResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.results.push(ClassifyResponse_Result.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): ClassifyResponse { + return { + results: globalThis.Array.isArray(object?.results) + ? object.results.map((e: any) => ClassifyResponse_Result.fromJSON(e)) + : [], + }; + }, + + toJSON(message: ClassifyResponse): unknown { + const obj: any = {}; + if (message.results?.length) { + obj.results = message.results.map((e) => ClassifyResponse_Result.toJSON(e)); + } + return obj; + }, + + create, I>>(base?: I): ClassifyResponse { + return ClassifyResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): ClassifyResponse { + const message = createBaseClassifyResponse(); + message.results = object.results?.map((e) => ClassifyResponse_Result.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseClassifyResponse_Result(): ClassifyResponse_Result { + return { label: "", score: 0 }; +} + +export const ClassifyResponse_Result = { + encode(message: ClassifyResponse_Result, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.label !== "") { + writer.uint32(10).string(message.label); + } + if (message.score !== 0) { + writer.uint32(21).float(message.score); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): ClassifyResponse_Result { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseClassifyResponse_Result(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.label = reader.string(); + continue; + case 2: + if (tag !== 21) { + break; + } + + message.score = reader.float(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): ClassifyResponse_Result { + return { + label: isSet(object.label) ? globalThis.String(object.label) : "", + score: isSet(object.score) ? 
globalThis.Number(object.score) : 0, + }; + }, + + toJSON(message: ClassifyResponse_Result): unknown { + const obj: any = {}; + if (message.label !== "") { + obj.label = message.label; + } + if (message.score !== 0) { + obj.score = message.score; + } + return obj; + }, + + create, I>>(base?: I): ClassifyResponse_Result { + return ClassifyResponse_Result.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): ClassifyResponse_Result { + const message = createBaseClassifyResponse_Result(); + message.label = object.label ?? ""; + message.score = object.score ?? 0; + return message; + }, +}; + +function createBaseTokenClassifyRequest(): TokenClassifyRequest { + return { text: [], model: undefined }; +} + +export const TokenClassifyRequest = { + encode(message: TokenClassifyRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.text) { + writer.uint32(10).string(v!); + } + if (message.model !== undefined) { + TokenClassifyRequest_Model.encode(message.model, writer.uint32(18).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TokenClassifyRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTokenClassifyRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text.push(reader.string()); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.model = TokenClassifyRequest_Model.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TokenClassifyRequest { + return { + text: globalThis.Array.isArray(object?.text) ? object.text.map((e: any) => globalThis.String(e)) : [], + model: isSet(object.model) ? TokenClassifyRequest_Model.fromJSON(object.model) : undefined, + }; + }, + + toJSON(message: TokenClassifyRequest): unknown { + const obj: any = {}; + if (message.text?.length) { + obj.text = message.text; + } + if (message.model !== undefined) { + obj.model = TokenClassifyRequest_Model.toJSON(message.model); + } + return obj; + }, + + create, I>>(base?: I): TokenClassifyRequest { + return TokenClassifyRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): TokenClassifyRequest { + const message = createBaseTokenClassifyRequest(); + message.text = object.text?.map((e) => e) || []; + message.model = (object.model !== undefined && object.model !== null) + ? TokenClassifyRequest_Model.fromPartial(object.model) + : undefined; + return message; + }, +}; + +function createBaseTokenClassifyRequest_Model(): TokenClassifyRequest_Model { + return { modelName: "", languageCode: "" }; +} + +export const TokenClassifyRequest_Model = { + encode(message: TokenClassifyRequest_Model, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.modelName !== "") { + writer.uint32(10).string(message.modelName); + } + if (message.languageCode !== "") { + writer.uint32(18).string(message.languageCode); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TokenClassifyRequest_Model { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseTokenClassifyRequest_Model(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.modelName = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.languageCode = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TokenClassifyRequest_Model { + return { + modelName: isSet(object.modelName) ? globalThis.String(object.modelName) : "", + languageCode: isSet(object.languageCode) ? globalThis.String(object.languageCode) : "", + }; + }, + + toJSON(message: TokenClassifyRequest_Model): unknown { + const obj: any = {}; + if (message.modelName !== "") { + obj.modelName = message.modelName; + } + if (message.languageCode !== "") { + obj.languageCode = message.languageCode; + } + return obj; + }, + + create, I>>(base?: I): TokenClassifyRequest_Model { + return TokenClassifyRequest_Model.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): TokenClassifyRequest_Model { + const message = createBaseTokenClassifyRequest_Model(); + message.modelName = object.modelName ?? ""; + message.languageCode = object.languageCode ?? ""; + return message; + }, +}; + +function createBaseTokenClassifyResponse(): TokenClassifyResponse { + return { results: [] }; +} + +export const TokenClassifyResponse = { + encode(message: TokenClassifyResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.results) { + TokenClassifyResponse_Result.encode(v!, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TokenClassifyResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTokenClassifyResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.results.push(TokenClassifyResponse_Result.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TokenClassifyResponse { + return { + results: globalThis.Array.isArray(object?.results) + ? object.results.map((e: any) => TokenClassifyResponse_Result.fromJSON(e)) + : [], + }; + }, + + toJSON(message: TokenClassifyResponse): unknown { + const obj: any = {}; + if (message.results?.length) { + obj.results = message.results.map((e) => TokenClassifyResponse_Result.toJSON(e)); + } + return obj; + }, + + create, I>>(base?: I): TokenClassifyResponse { + return TokenClassifyResponse.fromPartial(base ?? 
({} as any)); + }, + fromPartial, I>>(object: I): TokenClassifyResponse { + const message = createBaseTokenClassifyResponse(); + message.results = object.results?.map((e) => TokenClassifyResponse_Result.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseTokenClassifyResponse_Token(): TokenClassifyResponse_Token { + return { text: "", label: "", score: 0, start: 0, end: 0 }; +} + +export const TokenClassifyResponse_Token = { + encode(message: TokenClassifyResponse_Token, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.text !== "") { + writer.uint32(10).string(message.text); + } + if (message.label !== "") { + writer.uint32(18).string(message.label); + } + if (message.score !== 0) { + writer.uint32(29).float(message.score); + } + if (message.start !== 0) { + writer.uint32(32).int32(message.start); + } + if (message.end !== 0) { + writer.uint32(40).int32(message.end); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TokenClassifyResponse_Token { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTokenClassifyResponse_Token(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.label = reader.string(); + continue; + case 3: + if (tag !== 29) { + break; + } + + message.score = reader.float(); + continue; + case 4: + if (tag !== 32) { + break; + } + + message.start = reader.int32(); + continue; + case 5: + if (tag !== 40) { + break; + } + + message.end = reader.int32(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TokenClassifyResponse_Token { + return { + text: isSet(object.text) ? globalThis.String(object.text) : "", + label: isSet(object.label) ? globalThis.String(object.label) : "", + score: isSet(object.score) ? globalThis.Number(object.score) : 0, + start: isSet(object.start) ? globalThis.Number(object.start) : 0, + end: isSet(object.end) ? globalThis.Number(object.end) : 0, + }; + }, + + toJSON(message: TokenClassifyResponse_Token): unknown { + const obj: any = {}; + if (message.text !== "") { + obj.text = message.text; + } + if (message.label !== "") { + obj.label = message.label; + } + if (message.score !== 0) { + obj.score = message.score; + } + if (message.start !== 0) { + obj.start = Math.round(message.start); + } + if (message.end !== 0) { + obj.end = Math.round(message.end); + } + return obj; + }, + + create, I>>(base?: I): TokenClassifyResponse_Token { + return TokenClassifyResponse_Token.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): TokenClassifyResponse_Token { + const message = createBaseTokenClassifyResponse_Token(); + message.text = object.text ?? ""; + message.label = object.label ?? ""; + message.score = object.score ?? 0; + message.start = object.start ?? 0; + message.end = object.end ?? 
0; + return message; + }, +}; + +function createBaseTokenClassifyResponse_Result(): TokenClassifyResponse_Result { + return { tokens: [] }; +} + +export const TokenClassifyResponse_Result = { + encode(message: TokenClassifyResponse_Result, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.tokens) { + TokenClassifyResponse_Token.encode(v!, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TokenClassifyResponse_Result { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTokenClassifyResponse_Result(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.tokens.push(TokenClassifyResponse_Token.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TokenClassifyResponse_Result { + return { + tokens: globalThis.Array.isArray(object?.tokens) + ? object.tokens.map((e: any) => TokenClassifyResponse_Token.fromJSON(e)) + : [], + }; + }, + + toJSON(message: TokenClassifyResponse_Result): unknown { + const obj: any = {}; + if (message.tokens?.length) { + obj.tokens = message.tokens.map((e) => TokenClassifyResponse_Token.toJSON(e)); + } + return obj; + }, + + create, I>>(base?: I): TokenClassifyResponse_Result { + return TokenClassifyResponse_Result.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): TokenClassifyResponse_Result { + const message = createBaseTokenClassifyResponse_Result(); + message.tokens = object.tokens?.map((e) => TokenClassifyResponse_Token.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseAnalyzeEntitiesRequest(): AnalyzeEntitiesRequest { + return { text: "" }; +} + +export const AnalyzeEntitiesRequest = { + encode(message: AnalyzeEntitiesRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.text !== "") { + writer.uint32(10).string(message.text); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): AnalyzeEntitiesRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseAnalyzeEntitiesRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): AnalyzeEntitiesRequest { + return { text: isSet(object.text) ? globalThis.String(object.text) : "" }; + }, + + toJSON(message: AnalyzeEntitiesRequest): unknown { + const obj: any = {}; + if (message.text !== "") { + obj.text = message.text; + } + return obj; + }, + + create, I>>(base?: I): AnalyzeEntitiesRequest { + return AnalyzeEntitiesRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): AnalyzeEntitiesRequest { + const message = createBaseAnalyzeEntitiesRequest(); + message.text = object.text ?? 
""; + return message; + }, +}; + +function createBaseAnalyzeEntitiesResponse(): AnalyzeEntitiesResponse { + return { entities: [] }; +} + +export const AnalyzeEntitiesResponse = { + encode(message: AnalyzeEntitiesResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.entities) { + AnalyzeEntitiesResponse_Entity.encode(v!, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): AnalyzeEntitiesResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseAnalyzeEntitiesResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.entities.push(AnalyzeEntitiesResponse_Entity.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): AnalyzeEntitiesResponse { + return { + entities: globalThis.Array.isArray(object?.entities) + ? object.entities.map((e: any) => AnalyzeEntitiesResponse_Entity.fromJSON(e)) + : [], + }; + }, + + toJSON(message: AnalyzeEntitiesResponse): unknown { + const obj: any = {}; + if (message.entities?.length) { + obj.entities = message.entities.map((e) => AnalyzeEntitiesResponse_Entity.toJSON(e)); + } + return obj; + }, + + create, I>>(base?: I): AnalyzeEntitiesResponse { + return AnalyzeEntitiesResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): AnalyzeEntitiesResponse { + const message = createBaseAnalyzeEntitiesResponse(); + message.entities = object.entities?.map((e) => AnalyzeEntitiesResponse_Entity.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseAnalyzeEntitiesResponse_Entity(): AnalyzeEntitiesResponse_Entity { + return { text: "", type: "", score: 0, start: 0, end: 0 }; +} + +export const AnalyzeEntitiesResponse_Entity = { + encode(message: AnalyzeEntitiesResponse_Entity, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.text !== "") { + writer.uint32(10).string(message.text); + } + if (message.type !== "") { + writer.uint32(18).string(message.type); + } + if (message.score !== 0) { + writer.uint32(29).float(message.score); + } + if (message.start !== 0) { + writer.uint32(32).int32(message.start); + } + if (message.end !== 0) { + writer.uint32(40).int32(message.end); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): AnalyzeEntitiesResponse_Entity { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseAnalyzeEntitiesResponse_Entity(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.type = reader.string(); + continue; + case 3: + if (tag !== 29) { + break; + } + + message.score = reader.float(); + continue; + case 4: + if (tag !== 32) { + break; + } + + message.start = reader.int32(); + continue; + case 5: + if (tag !== 40) { + break; + } + + message.end = reader.int32(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): AnalyzeEntitiesResponse_Entity { + return { + text: isSet(object.text) ? globalThis.String(object.text) : "", + type: isSet(object.type) ? globalThis.String(object.type) : "", + score: isSet(object.score) ? globalThis.Number(object.score) : 0, + start: isSet(object.start) ? globalThis.Number(object.start) : 0, + end: isSet(object.end) ? globalThis.Number(object.end) : 0, + }; + }, + + toJSON(message: AnalyzeEntitiesResponse_Entity): unknown { + const obj: any = {}; + if (message.text !== "") { + obj.text = message.text; + } + if (message.type !== "") { + obj.type = message.type; + } + if (message.score !== 0) { + obj.score = message.score; + } + if (message.start !== 0) { + obj.start = Math.round(message.start); + } + if (message.end !== 0) { + obj.end = Math.round(message.end); + } + return obj; + }, + + create, I>>(base?: I): AnalyzeEntitiesResponse_Entity { + return AnalyzeEntitiesResponse_Entity.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>( + object: I, + ): AnalyzeEntitiesResponse_Entity { + const message = createBaseAnalyzeEntitiesResponse_Entity(); + message.text = object.text ?? ""; + message.type = object.type ?? ""; + message.score = object.score ?? 0; + message.start = object.start ?? 0; + message.end = object.end ?? 0; + return message; + }, +}; + +function createBaseAnalyzeIntentRequest(): AnalyzeIntentRequest { + return { text: "" }; +} + +export const AnalyzeIntentRequest = { + encode(message: AnalyzeIntentRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.text !== "") { + writer.uint32(10).string(message.text); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): AnalyzeIntentRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseAnalyzeIntentRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): AnalyzeIntentRequest { + return { text: isSet(object.text) ? globalThis.String(object.text) : "" }; + }, + + toJSON(message: AnalyzeIntentRequest): unknown { + const obj: any = {}; + if (message.text !== "") { + obj.text = message.text; + } + return obj; + }, + + create, I>>(base?: I): AnalyzeIntentRequest { + return AnalyzeIntentRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): AnalyzeIntentRequest { + const message = createBaseAnalyzeIntentRequest(); + message.text = object.text ?? 
""; + return message; + }, +}; + +function createBaseAnalyzeIntentResponse(): AnalyzeIntentResponse { + return { intent: "", confidence: 0, slots: [] }; +} + +export const AnalyzeIntentResponse = { + encode(message: AnalyzeIntentResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.intent !== "") { + writer.uint32(10).string(message.intent); + } + if (message.confidence !== 0) { + writer.uint32(21).float(message.confidence); + } + for (const v of message.slots) { + AnalyzeIntentResponse_Slot.encode(v!, writer.uint32(26).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): AnalyzeIntentResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseAnalyzeIntentResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.intent = reader.string(); + continue; + case 2: + if (tag !== 21) { + break; + } + + message.confidence = reader.float(); + continue; + case 3: + if (tag !== 26) { + break; + } + + message.slots.push(AnalyzeIntentResponse_Slot.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): AnalyzeIntentResponse { + return { + intent: isSet(object.intent) ? globalThis.String(object.intent) : "", + confidence: isSet(object.confidence) ? globalThis.Number(object.confidence) : 0, + slots: globalThis.Array.isArray(object?.slots) + ? object.slots.map((e: any) => AnalyzeIntentResponse_Slot.fromJSON(e)) + : [], + }; + }, + + toJSON(message: AnalyzeIntentResponse): unknown { + const obj: any = {}; + if (message.intent !== "") { + obj.intent = message.intent; + } + if (message.confidence !== 0) { + obj.confidence = message.confidence; + } + if (message.slots?.length) { + obj.slots = message.slots.map((e) => AnalyzeIntentResponse_Slot.toJSON(e)); + } + return obj; + }, + + create, I>>(base?: I): AnalyzeIntentResponse { + return AnalyzeIntentResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): AnalyzeIntentResponse { + const message = createBaseAnalyzeIntentResponse(); + message.intent = object.intent ?? ""; + message.confidence = object.confidence ?? 0; + message.slots = object.slots?.map((e) => AnalyzeIntentResponse_Slot.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseAnalyzeIntentResponse_Slot(): AnalyzeIntentResponse_Slot { + return { text: "", type: "", score: 0 }; +} + +export const AnalyzeIntentResponse_Slot = { + encode(message: AnalyzeIntentResponse_Slot, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.text !== "") { + writer.uint32(10).string(message.text); + } + if (message.type !== "") { + writer.uint32(18).string(message.type); + } + if (message.score !== 0) { + writer.uint32(29).float(message.score); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): AnalyzeIntentResponse_Slot { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseAnalyzeIntentResponse_Slot(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.type = reader.string(); + continue; + case 3: + if (tag !== 29) { + break; + } + + message.score = reader.float(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): AnalyzeIntentResponse_Slot { + return { + text: isSet(object.text) ? globalThis.String(object.text) : "", + type: isSet(object.type) ? globalThis.String(object.type) : "", + score: isSet(object.score) ? globalThis.Number(object.score) : 0, + }; + }, + + toJSON(message: AnalyzeIntentResponse_Slot): unknown { + const obj: any = {}; + if (message.text !== "") { + obj.text = message.text; + } + if (message.type !== "") { + obj.type = message.type; + } + if (message.score !== 0) { + obj.score = message.score; + } + return obj; + }, + + create, I>>(base?: I): AnalyzeIntentResponse_Slot { + return AnalyzeIntentResponse_Slot.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): AnalyzeIntentResponse_Slot { + const message = createBaseAnalyzeIntentResponse_Slot(); + message.text = object.text ?? ""; + message.type = object.type ?? ""; + message.score = object.score ?? 0; + return message; + }, +}; + +function createBaseTransformTextRequest(): TransformTextRequest { + return { text: "", model: "" }; +} + +export const TransformTextRequest = { + encode(message: TransformTextRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.text !== "") { + writer.uint32(10).string(message.text); + } + if (message.model !== "") { + writer.uint32(18).string(message.model); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TransformTextRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTransformTextRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.model = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TransformTextRequest { + return { + text: isSet(object.text) ? globalThis.String(object.text) : "", + model: isSet(object.model) ? globalThis.String(object.model) : "", + }; + }, + + toJSON(message: TransformTextRequest): unknown { + const obj: any = {}; + if (message.text !== "") { + obj.text = message.text; + } + if (message.model !== "") { + obj.model = message.model; + } + return obj; + }, + + create, I>>(base?: I): TransformTextRequest { + return TransformTextRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): TransformTextRequest { + const message = createBaseTransformTextRequest(); + message.text = object.text ?? ""; + message.model = object.model ?? 
""; + return message; + }, +}; + +function createBaseTransformTextResponse(): TransformTextResponse { + return { text: "" }; +} + +export const TransformTextResponse = { + encode(message: TransformTextResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.text !== "") { + writer.uint32(10).string(message.text); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TransformTextResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTransformTextResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TransformTextResponse { + return { text: isSet(object.text) ? globalThis.String(object.text) : "" }; + }, + + toJSON(message: TransformTextResponse): unknown { + const obj: any = {}; + if (message.text !== "") { + obj.text = message.text; + } + return obj; + }, + + create, I>>(base?: I): TransformTextResponse { + return TransformTextResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): TransformTextResponse { + const message = createBaseTransformTextResponse(); + message.text = object.text ?? ""; + return message; + }, +}; + +function createBaseNaturalQueryRequest(): NaturalQueryRequest { + return { query: "", context: "" }; +} + +export const NaturalQueryRequest = { + encode(message: NaturalQueryRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.query !== "") { + writer.uint32(10).string(message.query); + } + if (message.context !== "") { + writer.uint32(18).string(message.context); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): NaturalQueryRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseNaturalQueryRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.query = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.context = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): NaturalQueryRequest { + return { + query: isSet(object.query) ? globalThis.String(object.query) : "", + context: isSet(object.context) ? globalThis.String(object.context) : "", + }; + }, + + toJSON(message: NaturalQueryRequest): unknown { + const obj: any = {}; + if (message.query !== "") { + obj.query = message.query; + } + if (message.context !== "") { + obj.context = message.context; + } + return obj; + }, + + create, I>>(base?: I): NaturalQueryRequest { + return NaturalQueryRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): NaturalQueryRequest { + const message = createBaseNaturalQueryRequest(); + message.query = object.query ?? ""; + message.context = object.context ?? 
""; + return message; + }, +}; + +function createBaseNaturalQueryResponse(): NaturalQueryResponse { + return { response: "", confidence: 0 }; +} + +export const NaturalQueryResponse = { + encode(message: NaturalQueryResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.response !== "") { + writer.uint32(10).string(message.response); + } + if (message.confidence !== 0) { + writer.uint32(21).float(message.confidence); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): NaturalQueryResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseNaturalQueryResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.response = reader.string(); + continue; + case 2: + if (tag !== 21) { + break; + } + + message.confidence = reader.float(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): NaturalQueryResponse { + return { + response: isSet(object.response) ? globalThis.String(object.response) : "", + confidence: isSet(object.confidence) ? globalThis.Number(object.confidence) : 0, + }; + }, + + toJSON(message: NaturalQueryResponse): unknown { + const obj: any = {}; + if (message.response !== "") { + obj.response = message.response; + } + if (message.confidence !== 0) { + obj.confidence = message.confidence; + } + return obj; + }, + + create, I>>(base?: I): NaturalQueryResponse { + return NaturalQueryResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): NaturalQueryResponse { + const message = createBaseNaturalQueryResponse(); + message.response = object.response ?? ""; + message.confidence = object.confidence ?? 
+
+export interface RivaNLPService {
+  Classify(request: ClassifyRequest): Promise<ClassifyResponse>;
+  TokenClassify(request: TokenClassifyRequest): Promise<TokenClassifyResponse>;
+  AnalyzeEntities(request: AnalyzeEntitiesRequest): Promise<AnalyzeEntitiesResponse>;
+  AnalyzeIntent(request: AnalyzeIntentRequest): Promise<AnalyzeIntentResponse>;
+  TransformText(request: TransformTextRequest): Promise<TransformTextResponse>;
+  NaturalQuery(request: NaturalQueryRequest): Promise<NaturalQueryResponse>;
+}
+
+export const RivaNLPServiceServiceName = "nvidia.riva.RivaNLPService";
+export class RivaNLPServiceClientImpl implements RivaNLPService {
+  private readonly rpc: Rpc;
+  private readonly service: string;
+  constructor(rpc: Rpc, opts?: { service?: string }) {
+    this.service = opts?.service || RivaNLPServiceServiceName;
+    this.rpc = rpc;
+    this.Classify = this.Classify.bind(this);
+    this.TokenClassify = this.TokenClassify.bind(this);
+    this.AnalyzeEntities = this.AnalyzeEntities.bind(this);
+    this.AnalyzeIntent = this.AnalyzeIntent.bind(this);
+    this.TransformText = this.TransformText.bind(this);
+    this.NaturalQuery = this.NaturalQuery.bind(this);
+  }
+  Classify(request: ClassifyRequest): Promise<ClassifyResponse> {
+    const data = ClassifyRequest.encode(request).finish();
+    const promise = this.rpc.request(this.service, "Classify", data);
+    return promise.then((data) => ClassifyResponse.decode(_m0.Reader.create(data)));
+  }
+
+  TokenClassify(request: TokenClassifyRequest): Promise<TokenClassifyResponse> {
+    const data = TokenClassifyRequest.encode(request).finish();
+    const promise = this.rpc.request(this.service, "TokenClassify", data);
+    return promise.then((data) => TokenClassifyResponse.decode(_m0.Reader.create(data)));
+  }
+
+  AnalyzeEntities(request: AnalyzeEntitiesRequest): Promise<AnalyzeEntitiesResponse> {
+    const data = AnalyzeEntitiesRequest.encode(request).finish();
+    const promise = this.rpc.request(this.service, "AnalyzeEntities", data);
+    return promise.then((data) => AnalyzeEntitiesResponse.decode(_m0.Reader.create(data)));
+  }
+
+  AnalyzeIntent(request: AnalyzeIntentRequest): Promise<AnalyzeIntentResponse> {
+    const data = AnalyzeIntentRequest.encode(request).finish();
+    const promise = this.rpc.request(this.service, "AnalyzeIntent", data);
+    return promise.then((data) => AnalyzeIntentResponse.decode(_m0.Reader.create(data)));
+  }
+
+  TransformText(request: TransformTextRequest): Promise<TransformTextResponse> {
+    const data = TransformTextRequest.encode(request).finish();
+    const promise = this.rpc.request(this.service, "TransformText", data);
+    return promise.then((data) => TransformTextResponse.decode(_m0.Reader.create(data)));
+  }
+
+  NaturalQuery(request: NaturalQueryRequest): Promise<NaturalQueryResponse> {
+    const data = NaturalQueryRequest.encode(request).finish();
+    const promise = this.rpc.request(this.service, "NaturalQuery", data);
+    return promise.then((data) => NaturalQueryResponse.decode(_m0.Reader.create(data)));
+  }
+}
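+
+// Wiring sketch (illustrative; assumes @grpc/grpc-js as the transport — any
+// transport works): the generated client only needs an object satisfying the
+// unary Rpc interface below. A minimal adapter could look like:
+//
+//   import { Client, credentials } from "@grpc/grpc-js";
+//   const channel = new Client("localhost:50051", credentials.createInsecure());
+//   const rpc: Rpc = {
+//     request: (service, method, data) =>
+//       new Promise((resolve, reject) =>
+//         channel.makeUnaryRequest(
+//           `/${service}/${method}`,
+//           (req: Uint8Array) => Buffer.from(req),
+//           (res: Buffer) => new Uint8Array(res),
+//           data,
+//           (err, res) => (err ? reject(err) : resolve(res!)),
+//         ),
+//       ),
+//   };
+//   const nlp = new RivaNLPServiceClientImpl(rpc);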
+
+interface Rpc {
+  request(service: string, method: string, data: Uint8Array): Promise<Uint8Array>;
+}
+
+type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined;
+
+export type DeepPartial<T> = T extends Builtin ? T
+  : T extends globalThis.Array<infer U> ? globalThis.Array<DeepPartial<U>>
+  : T extends ReadonlyArray<infer U> ? ReadonlyArray<DeepPartial<U>>
+  : T extends {} ? { [K in keyof T]?: DeepPartial<T[K]> }
+  : Partial<T>;
+
+type KeysOfUnion<T> = T extends T ? keyof T : never;
+export type Exact<P, I extends P> = P extends Builtin ? P
+  : P & { [K in keyof P]: Exact<P[K], I[K]> } & { [K in Exclude<keyof I, KeysOfUnion<P>>]: never };
+
+function isSet(value: any): boolean {
+  return value !== null && value !== undefined;
+}
diff --git a/riva-ts-client/src/proto/riva_nmt.ts b/riva-ts-client/src/proto/riva_nmt.ts
new file mode 100644
index 00000000..4905cf4a
--- /dev/null
+++ b/riva-ts-client/src/proto/riva_nmt.ts
@@ -0,0 +1,1005 @@
+// Code generated by protoc-gen-ts_proto. DO NOT EDIT.
+// versions:
+//   protoc-gen-ts_proto  v1.181.2
+//   protoc               v5.29.3
+// source: riva_nmt.proto
+
+/* eslint-disable */
+import _m0 from "protobufjs/minimal";
+import { Observable } from "rxjs";
+import { map } from "rxjs/operators";
+import { AudioConfig } from "./riva_services";
+
+export const protobufPackage = "nvidia.riva";
+
+export interface TranslateTextRequest {
+  text: string;
+  sourceLanguage: string;
+  targetLanguage: string;
+  doNotTranslatePhrases: string[];
+}
+
+export interface TranslateTextResponse {
+  text: string;
+  translations: string[];
+}
+
+export interface StreamingS2SRequest {
+  config?: AudioConfig | undefined;
+  audioContent?: Uint8Array | undefined;
+}
+
+export interface StreamingS2SResponse {
+  result: StreamingS2SResponse_Result | undefined;
+}
+
+export interface StreamingS2SResponse_Result {
+  transcript: string;
+  translation: string;
+  isPartial: boolean;
+  audioContent: Uint8Array;
+}
+
+export interface StreamingS2TRequest {
+  config?: AudioConfig | undefined;
+  audioContent?: Uint8Array | undefined;
+}
+
+export interface StreamingS2TResponse {
+  result: StreamingS2TResponse_Result | undefined;
+}
+
+export interface StreamingS2TResponse_Result {
+  transcript: string;
+  translation: string;
+  isPartial: boolean;
+}
+
+export interface AvailableLanguageRequest {
+  model: string;
+}
+
+export interface AvailableLanguageResponse {
+  supportedLanguagePairs: AvailableLanguageResponse_LanguagePair[];
+}
+
+export interface AvailableLanguageResponse_LanguagePair {
+  sourceLanguageCode: string;
+  targetLanguageCode: string;
+}
+
+function createBaseTranslateTextRequest(): TranslateTextRequest {
+  return { text: "", sourceLanguage: "", targetLanguage: "", doNotTranslatePhrases: [] };
+}
+
+export const TranslateTextRequest = {
+  encode(message: TranslateTextRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
+    if (message.text !== "") {
+      writer.uint32(10).string(message.text);
+    }
+    if (message.sourceLanguage !== "") {
+      writer.uint32(18).string(message.sourceLanguage);
+    }
+    if (message.targetLanguage !== "") {
+      writer.uint32(26).string(message.targetLanguage);
+    }
+    for (const v of message.doNotTranslatePhrases) {
+      writer.uint32(34).string(v!);
+    }
+    return writer;
+  },
+
+  decode(input: _m0.Reader | Uint8Array, length?: number): TranslateTextRequest {
+    const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input);
+    let end = length === undefined ?
reader.len : reader.pos + length; + const message = createBaseTranslateTextRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.sourceLanguage = reader.string(); + continue; + case 3: + if (tag !== 26) { + break; + } + + message.targetLanguage = reader.string(); + continue; + case 4: + if (tag !== 34) { + break; + } + + message.doNotTranslatePhrases.push(reader.string()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TranslateTextRequest { + return { + text: isSet(object.text) ? globalThis.String(object.text) : "", + sourceLanguage: isSet(object.sourceLanguage) ? globalThis.String(object.sourceLanguage) : "", + targetLanguage: isSet(object.targetLanguage) ? globalThis.String(object.targetLanguage) : "", + doNotTranslatePhrases: globalThis.Array.isArray(object?.doNotTranslatePhrases) + ? object.doNotTranslatePhrases.map((e: any) => globalThis.String(e)) + : [], + }; + }, + + toJSON(message: TranslateTextRequest): unknown { + const obj: any = {}; + if (message.text !== "") { + obj.text = message.text; + } + if (message.sourceLanguage !== "") { + obj.sourceLanguage = message.sourceLanguage; + } + if (message.targetLanguage !== "") { + obj.targetLanguage = message.targetLanguage; + } + if (message.doNotTranslatePhrases?.length) { + obj.doNotTranslatePhrases = message.doNotTranslatePhrases; + } + return obj; + }, + + create, I>>(base?: I): TranslateTextRequest { + return TranslateTextRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): TranslateTextRequest { + const message = createBaseTranslateTextRequest(); + message.text = object.text ?? ""; + message.sourceLanguage = object.sourceLanguage ?? ""; + message.targetLanguage = object.targetLanguage ?? ""; + message.doNotTranslatePhrases = object.doNotTranslatePhrases?.map((e) => e) || []; + return message; + }, +}; + +function createBaseTranslateTextResponse(): TranslateTextResponse { + return { text: "", translations: [] }; +} + +export const TranslateTextResponse = { + encode(message: TranslateTextResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.text !== "") { + writer.uint32(10).string(message.text); + } + for (const v of message.translations) { + writer.uint32(18).string(v!); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): TranslateTextResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseTranslateTextResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.translations.push(reader.string()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): TranslateTextResponse { + return { + text: isSet(object.text) ? globalThis.String(object.text) : "", + translations: globalThis.Array.isArray(object?.translations) + ? 
object.translations.map((e: any) => globalThis.String(e)) + : [], + }; + }, + + toJSON(message: TranslateTextResponse): unknown { + const obj: any = {}; + if (message.text !== "") { + obj.text = message.text; + } + if (message.translations?.length) { + obj.translations = message.translations; + } + return obj; + }, + + create, I>>(base?: I): TranslateTextResponse { + return TranslateTextResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): TranslateTextResponse { + const message = createBaseTranslateTextResponse(); + message.text = object.text ?? ""; + message.translations = object.translations?.map((e) => e) || []; + return message; + }, +}; + +function createBaseStreamingS2SRequest(): StreamingS2SRequest { + return { config: undefined, audioContent: undefined }; +} + +export const StreamingS2SRequest = { + encode(message: StreamingS2SRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.config !== undefined) { + AudioConfig.encode(message.config, writer.uint32(10).fork()).ldelim(); + } + if (message.audioContent !== undefined) { + writer.uint32(18).bytes(message.audioContent); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): StreamingS2SRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseStreamingS2SRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.config = AudioConfig.decode(reader, reader.uint32()); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.audioContent = reader.bytes(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): StreamingS2SRequest { + return { + config: isSet(object.config) ? AudioConfig.fromJSON(object.config) : undefined, + audioContent: isSet(object.audioContent) ? bytesFromBase64(object.audioContent) : undefined, + }; + }, + + toJSON(message: StreamingS2SRequest): unknown { + const obj: any = {}; + if (message.config !== undefined) { + obj.config = AudioConfig.toJSON(message.config); + } + if (message.audioContent !== undefined) { + obj.audioContent = base64FromBytes(message.audioContent); + } + return obj; + }, + + create, I>>(base?: I): StreamingS2SRequest { + return StreamingS2SRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): StreamingS2SRequest { + const message = createBaseStreamingS2SRequest(); + message.config = (object.config !== undefined && object.config !== null) + ? AudioConfig.fromPartial(object.config) + : undefined; + message.audioContent = object.audioContent ?? undefined; + return message; + }, +}; + +function createBaseStreamingS2SResponse(): StreamingS2SResponse { + return { result: undefined }; +} + +export const StreamingS2SResponse = { + encode(message: StreamingS2SResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.result !== undefined) { + StreamingS2SResponse_Result.encode(message.result, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): StreamingS2SResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseStreamingS2SResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.result = StreamingS2SResponse_Result.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): StreamingS2SResponse { + return { result: isSet(object.result) ? StreamingS2SResponse_Result.fromJSON(object.result) : undefined }; + }, + + toJSON(message: StreamingS2SResponse): unknown { + const obj: any = {}; + if (message.result !== undefined) { + obj.result = StreamingS2SResponse_Result.toJSON(message.result); + } + return obj; + }, + + create, I>>(base?: I): StreamingS2SResponse { + return StreamingS2SResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): StreamingS2SResponse { + const message = createBaseStreamingS2SResponse(); + message.result = (object.result !== undefined && object.result !== null) + ? StreamingS2SResponse_Result.fromPartial(object.result) + : undefined; + return message; + }, +}; + +function createBaseStreamingS2SResponse_Result(): StreamingS2SResponse_Result { + return { transcript: "", translation: "", isPartial: false, audioContent: new Uint8Array(0) }; +} + +export const StreamingS2SResponse_Result = { + encode(message: StreamingS2SResponse_Result, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.transcript !== "") { + writer.uint32(10).string(message.transcript); + } + if (message.translation !== "") { + writer.uint32(18).string(message.translation); + } + if (message.isPartial !== false) { + writer.uint32(24).bool(message.isPartial); + } + if (message.audioContent.length !== 0) { + writer.uint32(34).bytes(message.audioContent); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): StreamingS2SResponse_Result { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseStreamingS2SResponse_Result(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.transcript = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.translation = reader.string(); + continue; + case 3: + if (tag !== 24) { + break; + } + + message.isPartial = reader.bool(); + continue; + case 4: + if (tag !== 34) { + break; + } + + message.audioContent = reader.bytes(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): StreamingS2SResponse_Result { + return { + transcript: isSet(object.transcript) ? globalThis.String(object.transcript) : "", + translation: isSet(object.translation) ? globalThis.String(object.translation) : "", + isPartial: isSet(object.isPartial) ? globalThis.Boolean(object.isPartial) : false, + audioContent: isSet(object.audioContent) ? 
bytesFromBase64(object.audioContent) : new Uint8Array(0), + }; + }, + + toJSON(message: StreamingS2SResponse_Result): unknown { + const obj: any = {}; + if (message.transcript !== "") { + obj.transcript = message.transcript; + } + if (message.translation !== "") { + obj.translation = message.translation; + } + if (message.isPartial !== false) { + obj.isPartial = message.isPartial; + } + if (message.audioContent.length !== 0) { + obj.audioContent = base64FromBytes(message.audioContent); + } + return obj; + }, + + create, I>>(base?: I): StreamingS2SResponse_Result { + return StreamingS2SResponse_Result.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): StreamingS2SResponse_Result { + const message = createBaseStreamingS2SResponse_Result(); + message.transcript = object.transcript ?? ""; + message.translation = object.translation ?? ""; + message.isPartial = object.isPartial ?? false; + message.audioContent = object.audioContent ?? new Uint8Array(0); + return message; + }, +}; + +function createBaseStreamingS2TRequest(): StreamingS2TRequest { + return { config: undefined, audioContent: undefined }; +} + +export const StreamingS2TRequest = { + encode(message: StreamingS2TRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.config !== undefined) { + AudioConfig.encode(message.config, writer.uint32(10).fork()).ldelim(); + } + if (message.audioContent !== undefined) { + writer.uint32(18).bytes(message.audioContent); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): StreamingS2TRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseStreamingS2TRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.config = AudioConfig.decode(reader, reader.uint32()); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.audioContent = reader.bytes(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): StreamingS2TRequest { + return { + config: isSet(object.config) ? AudioConfig.fromJSON(object.config) : undefined, + audioContent: isSet(object.audioContent) ? bytesFromBase64(object.audioContent) : undefined, + }; + }, + + toJSON(message: StreamingS2TRequest): unknown { + const obj: any = {}; + if (message.config !== undefined) { + obj.config = AudioConfig.toJSON(message.config); + } + if (message.audioContent !== undefined) { + obj.audioContent = base64FromBytes(message.audioContent); + } + return obj; + }, + + create, I>>(base?: I): StreamingS2TRequest { + return StreamingS2TRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): StreamingS2TRequest { + const message = createBaseStreamingS2TRequest(); + message.config = (object.config !== undefined && object.config !== null) + ? AudioConfig.fromPartial(object.config) + : undefined; + message.audioContent = object.audioContent ?? 
undefined; + return message; + }, +}; + +function createBaseStreamingS2TResponse(): StreamingS2TResponse { + return { result: undefined }; +} + +export const StreamingS2TResponse = { + encode(message: StreamingS2TResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.result !== undefined) { + StreamingS2TResponse_Result.encode(message.result, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): StreamingS2TResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseStreamingS2TResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.result = StreamingS2TResponse_Result.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): StreamingS2TResponse { + return { result: isSet(object.result) ? StreamingS2TResponse_Result.fromJSON(object.result) : undefined }; + }, + + toJSON(message: StreamingS2TResponse): unknown { + const obj: any = {}; + if (message.result !== undefined) { + obj.result = StreamingS2TResponse_Result.toJSON(message.result); + } + return obj; + }, + + create, I>>(base?: I): StreamingS2TResponse { + return StreamingS2TResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): StreamingS2TResponse { + const message = createBaseStreamingS2TResponse(); + message.result = (object.result !== undefined && object.result !== null) + ? StreamingS2TResponse_Result.fromPartial(object.result) + : undefined; + return message; + }, +}; + +function createBaseStreamingS2TResponse_Result(): StreamingS2TResponse_Result { + return { transcript: "", translation: "", isPartial: false }; +} + +export const StreamingS2TResponse_Result = { + encode(message: StreamingS2TResponse_Result, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.transcript !== "") { + writer.uint32(10).string(message.transcript); + } + if (message.translation !== "") { + writer.uint32(18).string(message.translation); + } + if (message.isPartial !== false) { + writer.uint32(24).bool(message.isPartial); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): StreamingS2TResponse_Result { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseStreamingS2TResponse_Result(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.transcript = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.translation = reader.string(); + continue; + case 3: + if (tag !== 24) { + break; + } + + message.isPartial = reader.bool(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): StreamingS2TResponse_Result { + return { + transcript: isSet(object.transcript) ? globalThis.String(object.transcript) : "", + translation: isSet(object.translation) ? globalThis.String(object.translation) : "", + isPartial: isSet(object.isPartial) ? 
globalThis.Boolean(object.isPartial) : false, + }; + }, + + toJSON(message: StreamingS2TResponse_Result): unknown { + const obj: any = {}; + if (message.transcript !== "") { + obj.transcript = message.transcript; + } + if (message.translation !== "") { + obj.translation = message.translation; + } + if (message.isPartial !== false) { + obj.isPartial = message.isPartial; + } + return obj; + }, + + create, I>>(base?: I): StreamingS2TResponse_Result { + return StreamingS2TResponse_Result.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): StreamingS2TResponse_Result { + const message = createBaseStreamingS2TResponse_Result(); + message.transcript = object.transcript ?? ""; + message.translation = object.translation ?? ""; + message.isPartial = object.isPartial ?? false; + return message; + }, +}; + +function createBaseAvailableLanguageRequest(): AvailableLanguageRequest { + return { model: "" }; +} + +export const AvailableLanguageRequest = { + encode(message: AvailableLanguageRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.model !== "") { + writer.uint32(10).string(message.model); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): AvailableLanguageRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseAvailableLanguageRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.model = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): AvailableLanguageRequest { + return { model: isSet(object.model) ? globalThis.String(object.model) : "" }; + }, + + toJSON(message: AvailableLanguageRequest): unknown { + const obj: any = {}; + if (message.model !== "") { + obj.model = message.model; + } + return obj; + }, + + create, I>>(base?: I): AvailableLanguageRequest { + return AvailableLanguageRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): AvailableLanguageRequest { + const message = createBaseAvailableLanguageRequest(); + message.model = object.model ?? ""; + return message; + }, +}; + +function createBaseAvailableLanguageResponse(): AvailableLanguageResponse { + return { supportedLanguagePairs: [] }; +} + +export const AvailableLanguageResponse = { + encode(message: AvailableLanguageResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.supportedLanguagePairs) { + AvailableLanguageResponse_LanguagePair.encode(v!, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): AvailableLanguageResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? 
reader.len : reader.pos + length; + const message = createBaseAvailableLanguageResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.supportedLanguagePairs.push(AvailableLanguageResponse_LanguagePair.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): AvailableLanguageResponse { + return { + supportedLanguagePairs: globalThis.Array.isArray(object?.supportedLanguagePairs) + ? object.supportedLanguagePairs.map((e: any) => AvailableLanguageResponse_LanguagePair.fromJSON(e)) + : [], + }; + }, + + toJSON(message: AvailableLanguageResponse): unknown { + const obj: any = {}; + if (message.supportedLanguagePairs?.length) { + obj.supportedLanguagePairs = message.supportedLanguagePairs.map((e) => + AvailableLanguageResponse_LanguagePair.toJSON(e) + ); + } + return obj; + }, + + create, I>>(base?: I): AvailableLanguageResponse { + return AvailableLanguageResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): AvailableLanguageResponse { + const message = createBaseAvailableLanguageResponse(); + message.supportedLanguagePairs = + object.supportedLanguagePairs?.map((e) => AvailableLanguageResponse_LanguagePair.fromPartial(e)) || []; + return message; + }, +}; + +function createBaseAvailableLanguageResponse_LanguagePair(): AvailableLanguageResponse_LanguagePair { + return { sourceLanguageCode: "", targetLanguageCode: "" }; +} + +export const AvailableLanguageResponse_LanguagePair = { + encode(message: AvailableLanguageResponse_LanguagePair, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.sourceLanguageCode !== "") { + writer.uint32(10).string(message.sourceLanguageCode); + } + if (message.targetLanguageCode !== "") { + writer.uint32(18).string(message.targetLanguageCode); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): AvailableLanguageResponse_LanguagePair { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseAvailableLanguageResponse_LanguagePair(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.sourceLanguageCode = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.targetLanguageCode = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): AvailableLanguageResponse_LanguagePair { + return { + sourceLanguageCode: isSet(object.sourceLanguageCode) ? globalThis.String(object.sourceLanguageCode) : "", + targetLanguageCode: isSet(object.targetLanguageCode) ? globalThis.String(object.targetLanguageCode) : "", + }; + }, + + toJSON(message: AvailableLanguageResponse_LanguagePair): unknown { + const obj: any = {}; + if (message.sourceLanguageCode !== "") { + obj.sourceLanguageCode = message.sourceLanguageCode; + } + if (message.targetLanguageCode !== "") { + obj.targetLanguageCode = message.targetLanguageCode; + } + return obj; + }, + + create, I>>( + base?: I, + ): AvailableLanguageResponse_LanguagePair { + return AvailableLanguageResponse_LanguagePair.fromPartial(base ?? 
({} as any));
+  },
+  fromPartial<I extends Exact<DeepPartial<AvailableLanguageResponse_LanguagePair>, I>>(
+    object: I,
+  ): AvailableLanguageResponse_LanguagePair {
+    const message = createBaseAvailableLanguageResponse_LanguagePair();
+    message.sourceLanguageCode = object.sourceLanguageCode ?? "";
+    message.targetLanguageCode = object.targetLanguageCode ?? "";
+    return message;
+  },
+};
+
+export interface RivaNMTService {
+  TranslateText(request: TranslateTextRequest): Promise<TranslateTextResponse>;
+  StreamingTranslateSpeechToSpeech(request: Observable<StreamingS2SRequest>): Observable<StreamingS2SResponse>;
+  StreamingTranslateSpeechToText(request: Observable<StreamingS2TRequest>): Observable<StreamingS2TResponse>;
+  ListSupportedLanguagePairs(request: AvailableLanguageRequest): Promise<AvailableLanguageResponse>;
+}
+
+export const RivaNMTServiceServiceName = "nvidia.riva.RivaNMTService";
+export class RivaNMTServiceClientImpl implements RivaNMTService {
+  private readonly rpc: Rpc;
+  private readonly service: string;
+  constructor(rpc: Rpc, opts?: { service?: string }) {
+    this.service = opts?.service || RivaNMTServiceServiceName;
+    this.rpc = rpc;
+    this.TranslateText = this.TranslateText.bind(this);
+    this.StreamingTranslateSpeechToSpeech = this.StreamingTranslateSpeechToSpeech.bind(this);
+    this.StreamingTranslateSpeechToText = this.StreamingTranslateSpeechToText.bind(this);
+    this.ListSupportedLanguagePairs = this.ListSupportedLanguagePairs.bind(this);
+  }
+  TranslateText(request: TranslateTextRequest): Promise<TranslateTextResponse> {
+    const data = TranslateTextRequest.encode(request).finish();
+    const promise = this.rpc.request(this.service, "TranslateText", data);
+    return promise.then((data) => TranslateTextResponse.decode(_m0.Reader.create(data)));
+  }
+
+  StreamingTranslateSpeechToSpeech(request: Observable<StreamingS2SRequest>): Observable<StreamingS2SResponse> {
+    const data = request.pipe(map((request) => StreamingS2SRequest.encode(request).finish()));
+    const result = this.rpc.bidirectionalStreamingRequest(this.service, "StreamingTranslateSpeechToSpeech", data);
+    return result.pipe(map((data) => StreamingS2SResponse.decode(_m0.Reader.create(data))));
+  }
+
+  StreamingTranslateSpeechToText(request: Observable<StreamingS2TRequest>): Observable<StreamingS2TResponse> {
+    const data = request.pipe(map((request) => StreamingS2TRequest.encode(request).finish()));
+    const result = this.rpc.bidirectionalStreamingRequest(this.service, "StreamingTranslateSpeechToText", data);
+    return result.pipe(map((data) => StreamingS2TResponse.decode(_m0.Reader.create(data))));
+  }
+
+  ListSupportedLanguagePairs(request: AvailableLanguageRequest): Promise<AvailableLanguageResponse> {
+    const data = AvailableLanguageRequest.encode(request).finish();
+    const promise = this.rpc.request(this.service, "ListSupportedLanguagePairs", data);
+    return promise.then((data) => AvailableLanguageResponse.decode(_m0.Reader.create(data)));
+  }
+}
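+
+// Streaming usage sketch (illustrative; assumes an rpc object whose
+// bidirectionalStreamingRequest bridges rxjs Observables to a gRPC duplex
+// stream — see the Rpc interface below):
+//
+//   import { Subject } from "rxjs";
+//   const requests = new Subject<StreamingS2TRequest>();
+//   const nmt = new RivaNMTServiceClientImpl(rpc);
+//   nmt.StreamingTranslateSpeechToText(requests).subscribe((res) => {
+//     console.log(res.result?.translation);
+//   });
+//   requests.next({ config: { encoding: 1 /* LINEAR_PCM */, sampleRateHz: 16000, languageCode: "en-US", enableWordTimeOffsets: false, channels: 1 } });
+//   // ...push audioContent chunks, then requests.complete();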
+
+interface Rpc {
+  request(service: string, method: string, data: Uint8Array): Promise<Uint8Array>;
+  clientStreamingRequest(service: string, method: string, data: Observable<Uint8Array>): Promise<Uint8Array>;
+  serverStreamingRequest(service: string, method: string, data: Uint8Array): Observable<Uint8Array>;
+  bidirectionalStreamingRequest(service: string, method: string, data: Observable<Uint8Array>): Observable<Uint8Array>;
+}
+
+function bytesFromBase64(b64: string): Uint8Array {
+  if ((globalThis as any).Buffer) {
+    return Uint8Array.from(globalThis.Buffer.from(b64, "base64"));
+  } else {
+    const bin = globalThis.atob(b64);
+    const arr = new Uint8Array(bin.length);
+    for (let i = 0; i < bin.length; ++i) {
+      arr[i] = bin.charCodeAt(i);
+    }
+    return arr;
+  }
+}
+
+function base64FromBytes(arr: Uint8Array): string {
+  if ((globalThis as any).Buffer) {
+    return globalThis.Buffer.from(arr).toString("base64");
+  } else {
+    const bin: string[] = [];
+    arr.forEach((byte) => {
+      bin.push(globalThis.String.fromCharCode(byte));
+    });
+    return globalThis.btoa(bin.join(""));
+  }
+}
+
+type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined;
+
+export type DeepPartial<T> = T extends Builtin ? T
+  : T extends globalThis.Array<infer U> ? globalThis.Array<DeepPartial<U>>
+  : T extends ReadonlyArray<infer U> ? ReadonlyArray<DeepPartial<U>>
+  : T extends {} ? { [K in keyof T]?: DeepPartial<T[K]> }
+  : Partial<T>;
+
+type KeysOfUnion<T> = T extends T ? keyof T : never;
+export type Exact<P, I extends P> = P extends Builtin ? P
+  : P & { [K in keyof P]: Exact<P[K], I[K]> } & { [K in Exclude<keyof I, KeysOfUnion<P>>]: never };
+
+function isSet(value: any): boolean {
+  return value !== null && value !== undefined;
+}
diff --git a/riva-ts-client/src/proto/riva_services.ts b/riva-ts-client/src/proto/riva_services.ts
new file mode 100644
index 00000000..83752f56
--- /dev/null
+++ b/riva-ts-client/src/proto/riva_services.ts
@@ -0,0 +1,791 @@
+// Code generated by protoc-gen-ts_proto. DO NOT EDIT.
+// versions:
+//   protoc-gen-ts_proto  v1.181.2
+//   protoc               v5.29.3
+// source: riva_services.proto
+
+/* eslint-disable */
+import _m0 from "protobufjs/minimal";
+import { Observable } from "rxjs";
+import { map } from "rxjs/operators";
+
+export const protobufPackage = "nvidia.riva";
+
+export enum AudioEncoding {
+  ENCODING_UNSPECIFIED = 0,
+  LINEAR_PCM = 1,
+  FLAC = 2,
+  MULAW = 3,
+  ALAW = 4,
+  UNRECOGNIZED = -1,
+}
+
+export function audioEncodingFromJSON(object: any): AudioEncoding {
+  switch (object) {
+    case 0:
+    case "ENCODING_UNSPECIFIED":
+      return AudioEncoding.ENCODING_UNSPECIFIED;
+    case 1:
+    case "LINEAR_PCM":
+      return AudioEncoding.LINEAR_PCM;
+    case 2:
+    case "FLAC":
+      return AudioEncoding.FLAC;
+    case 3:
+    case "MULAW":
+      return AudioEncoding.MULAW;
+    case 4:
+    case "ALAW":
+      return AudioEncoding.ALAW;
+    case -1:
+    case "UNRECOGNIZED":
+    default:
+      return AudioEncoding.UNRECOGNIZED;
+  }
+}
+
+export function audioEncodingToJSON(object: AudioEncoding): string {
+  switch (object) {
+    case AudioEncoding.ENCODING_UNSPECIFIED:
+      return "ENCODING_UNSPECIFIED";
+    case AudioEncoding.LINEAR_PCM:
+      return "LINEAR_PCM";
+    case AudioEncoding.FLAC:
+      return "FLAC";
+    case AudioEncoding.MULAW:
+      return "MULAW";
+    case AudioEncoding.ALAW:
+      return "ALAW";
+    case AudioEncoding.UNRECOGNIZED:
+    default:
+      return "UNRECOGNIZED";
+  }
+}
+
+/** Message types for Speech Synthesis */
+export interface SynthesizeRequest {
+  text: string;
+  languageCode: string;
+  sampleRateHz: number;
+  encoding: AudioEncoding;
+  voiceName: string;
+  customDictionary?: string | undefined;
+}
+
+export interface SynthesizeResponse {
+  audio: Uint8Array;
+  audioConfig: AudioConfig | undefined;
+}
+
+export interface GetRivaSynthesisConfigRequest {
+}
+
+export interface GetRivaSynthesisConfigResponse {
+  modelConfig: GetRivaSynthesisConfigResponse_ModelConfig[];
+}
+
+export interface GetRivaSynthesisConfigResponse_ModelConfig {
+  parameters: GetRivaSynthesisConfigResponse_ModelConfig_Parameters | undefined;
+}
+
+export interface GetRivaSynthesisConfigResponse_ModelConfig_Parameters {
+  languageCode: string;
+  voiceName: string;
+  subvoices: string;
+}
+
+/** Common types */
+export interface AudioConfig {
+  encoding: AudioEncoding;
+  sampleRateHz: number;
+  languageCode: string;
+  enableWordTimeOffsets: boolean;
+  channels: number;
+}
+
+function createBaseSynthesizeRequest(): SynthesizeRequest {
+  return { text: "", languageCode: "", sampleRateHz: 0, encoding: 0, voiceName: "", customDictionary: undefined };
+}
+
+export const SynthesizeRequest = {
+  encode(message:
SynthesizeRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.text !== "") { + writer.uint32(10).string(message.text); + } + if (message.languageCode !== "") { + writer.uint32(18).string(message.languageCode); + } + if (message.sampleRateHz !== 0) { + writer.uint32(24).int32(message.sampleRateHz); + } + if (message.encoding !== 0) { + writer.uint32(32).int32(message.encoding); + } + if (message.voiceName !== "") { + writer.uint32(42).string(message.voiceName); + } + if (message.customDictionary !== undefined) { + writer.uint32(50).string(message.customDictionary); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): SynthesizeRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseSynthesizeRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.text = reader.string(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.languageCode = reader.string(); + continue; + case 3: + if (tag !== 24) { + break; + } + + message.sampleRateHz = reader.int32(); + continue; + case 4: + if (tag !== 32) { + break; + } + + message.encoding = reader.int32() as any; + continue; + case 5: + if (tag !== 42) { + break; + } + + message.voiceName = reader.string(); + continue; + case 6: + if (tag !== 50) { + break; + } + + message.customDictionary = reader.string(); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): SynthesizeRequest { + return { + text: isSet(object.text) ? globalThis.String(object.text) : "", + languageCode: isSet(object.languageCode) ? globalThis.String(object.languageCode) : "", + sampleRateHz: isSet(object.sampleRateHz) ? globalThis.Number(object.sampleRateHz) : 0, + encoding: isSet(object.encoding) ? audioEncodingFromJSON(object.encoding) : 0, + voiceName: isSet(object.voiceName) ? globalThis.String(object.voiceName) : "", + customDictionary: isSet(object.customDictionary) ? globalThis.String(object.customDictionary) : undefined, + }; + }, + + toJSON(message: SynthesizeRequest): unknown { + const obj: any = {}; + if (message.text !== "") { + obj.text = message.text; + } + if (message.languageCode !== "") { + obj.languageCode = message.languageCode; + } + if (message.sampleRateHz !== 0) { + obj.sampleRateHz = Math.round(message.sampleRateHz); + } + if (message.encoding !== 0) { + obj.encoding = audioEncodingToJSON(message.encoding); + } + if (message.voiceName !== "") { + obj.voiceName = message.voiceName; + } + if (message.customDictionary !== undefined) { + obj.customDictionary = message.customDictionary; + } + return obj; + }, + + create, I>>(base?: I): SynthesizeRequest { + return SynthesizeRequest.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): SynthesizeRequest { + const message = createBaseSynthesizeRequest(); + message.text = object.text ?? ""; + message.languageCode = object.languageCode ?? ""; + message.sampleRateHz = object.sampleRateHz ?? 0; + message.encoding = object.encoding ?? 0; + message.voiceName = object.voiceName ?? ""; + message.customDictionary = object.customDictionary ?? 
undefined; + return message; + }, +}; + +function createBaseSynthesizeResponse(): SynthesizeResponse { + return { audio: new Uint8Array(0), audioConfig: undefined }; +} + +export const SynthesizeResponse = { + encode(message: SynthesizeResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.audio.length !== 0) { + writer.uint32(10).bytes(message.audio); + } + if (message.audioConfig !== undefined) { + AudioConfig.encode(message.audioConfig, writer.uint32(18).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): SynthesizeResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseSynthesizeResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.audio = reader.bytes(); + continue; + case 2: + if (tag !== 18) { + break; + } + + message.audioConfig = AudioConfig.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): SynthesizeResponse { + return { + audio: isSet(object.audio) ? bytesFromBase64(object.audio) : new Uint8Array(0), + audioConfig: isSet(object.audioConfig) ? AudioConfig.fromJSON(object.audioConfig) : undefined, + }; + }, + + toJSON(message: SynthesizeResponse): unknown { + const obj: any = {}; + if (message.audio.length !== 0) { + obj.audio = base64FromBytes(message.audio); + } + if (message.audioConfig !== undefined) { + obj.audioConfig = AudioConfig.toJSON(message.audioConfig); + } + return obj; + }, + + create, I>>(base?: I): SynthesizeResponse { + return SynthesizeResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>(object: I): SynthesizeResponse { + const message = createBaseSynthesizeResponse(); + message.audio = object.audio ?? new Uint8Array(0); + message.audioConfig = (object.audioConfig !== undefined && object.audioConfig !== null) + ? AudioConfig.fromPartial(object.audioConfig) + : undefined; + return message; + }, +}; + +function createBaseGetRivaSynthesisConfigRequest(): GetRivaSynthesisConfigRequest { + return {}; +} + +export const GetRivaSynthesisConfigRequest = { + encode(_: GetRivaSynthesisConfigRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GetRivaSynthesisConfigRequest { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGetRivaSynthesisConfigRequest(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(_: any): GetRivaSynthesisConfigRequest { + return {}; + }, + + toJSON(_: GetRivaSynthesisConfigRequest): unknown { + const obj: any = {}; + return obj; + }, + + create, I>>(base?: I): GetRivaSynthesisConfigRequest { + return GetRivaSynthesisConfigRequest.fromPartial(base ?? 
({} as any)); + }, + fromPartial, I>>(_: I): GetRivaSynthesisConfigRequest { + const message = createBaseGetRivaSynthesisConfigRequest(); + return message; + }, +}; + +function createBaseGetRivaSynthesisConfigResponse(): GetRivaSynthesisConfigResponse { + return { modelConfig: [] }; +} + +export const GetRivaSynthesisConfigResponse = { + encode(message: GetRivaSynthesisConfigResponse, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + for (const v of message.modelConfig) { + GetRivaSynthesisConfigResponse_ModelConfig.encode(v!, writer.uint32(10).fork()).ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GetRivaSynthesisConfigResponse { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGetRivaSynthesisConfigResponse(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.modelConfig.push(GetRivaSynthesisConfigResponse_ModelConfig.decode(reader, reader.uint32())); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GetRivaSynthesisConfigResponse { + return { + modelConfig: globalThis.Array.isArray(object?.modelConfig) + ? object.modelConfig.map((e: any) => GetRivaSynthesisConfigResponse_ModelConfig.fromJSON(e)) + : [], + }; + }, + + toJSON(message: GetRivaSynthesisConfigResponse): unknown { + const obj: any = {}; + if (message.modelConfig?.length) { + obj.modelConfig = message.modelConfig.map((e) => GetRivaSynthesisConfigResponse_ModelConfig.toJSON(e)); + } + return obj; + }, + + create, I>>(base?: I): GetRivaSynthesisConfigResponse { + return GetRivaSynthesisConfigResponse.fromPartial(base ?? ({} as any)); + }, + fromPartial, I>>( + object: I, + ): GetRivaSynthesisConfigResponse { + const message = createBaseGetRivaSynthesisConfigResponse(); + message.modelConfig = object.modelConfig?.map((e) => GetRivaSynthesisConfigResponse_ModelConfig.fromPartial(e)) || + []; + return message; + }, +}; + +function createBaseGetRivaSynthesisConfigResponse_ModelConfig(): GetRivaSynthesisConfigResponse_ModelConfig { + return { parameters: undefined }; +} + +export const GetRivaSynthesisConfigResponse_ModelConfig = { + encode(message: GetRivaSynthesisConfigResponse_ModelConfig, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer { + if (message.parameters !== undefined) { + GetRivaSynthesisConfigResponse_ModelConfig_Parameters.encode(message.parameters, writer.uint32(10).fork()) + .ldelim(); + } + return writer; + }, + + decode(input: _m0.Reader | Uint8Array, length?: number): GetRivaSynthesisConfigResponse_ModelConfig { + const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input); + let end = length === undefined ? reader.len : reader.pos + length; + const message = createBaseGetRivaSynthesisConfigResponse_ModelConfig(); + while (reader.pos < end) { + const tag = reader.uint32(); + switch (tag >>> 3) { + case 1: + if (tag !== 10) { + break; + } + + message.parameters = GetRivaSynthesisConfigResponse_ModelConfig_Parameters.decode(reader, reader.uint32()); + continue; + } + if ((tag & 7) === 4 || tag === 0) { + break; + } + reader.skipType(tag & 7); + } + return message; + }, + + fromJSON(object: any): GetRivaSynthesisConfigResponse_ModelConfig { + return { + parameters: isSet(object.parameters) + ? 
+
+function createBaseGetRivaSynthesisConfigResponse_ModelConfig(): GetRivaSynthesisConfigResponse_ModelConfig {
+  return { parameters: undefined };
+}
+
+export const GetRivaSynthesisConfigResponse_ModelConfig = {
+  encode(message: GetRivaSynthesisConfigResponse_ModelConfig, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
+    if (message.parameters !== undefined) {
+      GetRivaSynthesisConfigResponse_ModelConfig_Parameters.encode(message.parameters, writer.uint32(10).fork())
+        .ldelim();
+    }
+    return writer;
+  },
+
+  decode(input: _m0.Reader | Uint8Array, length?: number): GetRivaSynthesisConfigResponse_ModelConfig {
+    const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input);
+    let end = length === undefined ? reader.len : reader.pos + length;
+    const message = createBaseGetRivaSynthesisConfigResponse_ModelConfig();
+    while (reader.pos < end) {
+      const tag = reader.uint32();
+      switch (tag >>> 3) {
+        case 1:
+          if (tag !== 10) {
+            break;
+          }
+
+          message.parameters = GetRivaSynthesisConfigResponse_ModelConfig_Parameters.decode(reader, reader.uint32());
+          continue;
+      }
+      if ((tag & 7) === 4 || tag === 0) {
+        break;
+      }
+      reader.skipType(tag & 7);
+    }
+    return message;
+  },
+
+  fromJSON(object: any): GetRivaSynthesisConfigResponse_ModelConfig {
+    return {
+      parameters: isSet(object.parameters)
+        ? GetRivaSynthesisConfigResponse_ModelConfig_Parameters.fromJSON(object.parameters)
+        : undefined,
+    };
+  },
+
+  toJSON(message: GetRivaSynthesisConfigResponse_ModelConfig): unknown {
+    const obj: any = {};
+    if (message.parameters !== undefined) {
+      obj.parameters = GetRivaSynthesisConfigResponse_ModelConfig_Parameters.toJSON(message.parameters);
+    }
+    return obj;
+  },
+
+  create<I extends Exact<DeepPartial<GetRivaSynthesisConfigResponse_ModelConfig>, I>>(
+    base?: I,
+  ): GetRivaSynthesisConfigResponse_ModelConfig {
+    return GetRivaSynthesisConfigResponse_ModelConfig.fromPartial(base ?? ({} as any));
+  },
+  fromPartial<I extends Exact<DeepPartial<GetRivaSynthesisConfigResponse_ModelConfig>, I>>(
+    object: I,
+  ): GetRivaSynthesisConfigResponse_ModelConfig {
+    const message = createBaseGetRivaSynthesisConfigResponse_ModelConfig();
+    message.parameters = (object.parameters !== undefined && object.parameters !== null)
+      ? GetRivaSynthesisConfigResponse_ModelConfig_Parameters.fromPartial(object.parameters)
+      : undefined;
+    return message;
+  },
+};
+
+function createBaseGetRivaSynthesisConfigResponse_ModelConfig_Parameters(): GetRivaSynthesisConfigResponse_ModelConfig_Parameters {
+  return { languageCode: "", voiceName: "", subvoices: "" };
+}
+
+export const GetRivaSynthesisConfigResponse_ModelConfig_Parameters = {
+  encode(
+    message: GetRivaSynthesisConfigResponse_ModelConfig_Parameters,
+    writer: _m0.Writer = _m0.Writer.create(),
+  ): _m0.Writer {
+    if (message.languageCode !== "") {
+      writer.uint32(10).string(message.languageCode);
+    }
+    if (message.voiceName !== "") {
+      writer.uint32(18).string(message.voiceName);
+    }
+    if (message.subvoices !== "") {
+      writer.uint32(26).string(message.subvoices);
+    }
+    return writer;
+  },
+
+  decode(input: _m0.Reader | Uint8Array, length?: number): GetRivaSynthesisConfigResponse_ModelConfig_Parameters {
+    const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input);
+    let end = length === undefined ? reader.len : reader.pos + length;
+    const message = createBaseGetRivaSynthesisConfigResponse_ModelConfig_Parameters();
+    while (reader.pos < end) {
+      const tag = reader.uint32();
+      switch (tag >>> 3) {
+        case 1:
+          if (tag !== 10) {
+            break;
+          }
+
+          message.languageCode = reader.string();
+          continue;
+        case 2:
+          if (tag !== 18) {
+            break;
+          }
+
+          message.voiceName = reader.string();
+          continue;
+        case 3:
+          if (tag !== 26) {
+            break;
+          }
+
+          message.subvoices = reader.string();
+          continue;
+      }
+      if ((tag & 7) === 4 || tag === 0) {
+        break;
+      }
+      reader.skipType(tag & 7);
+    }
+    return message;
+  },
+
+  fromJSON(object: any): GetRivaSynthesisConfigResponse_ModelConfig_Parameters {
+    return {
+      languageCode: isSet(object.languageCode) ? globalThis.String(object.languageCode) : "",
+      voiceName: isSet(object.voiceName) ? globalThis.String(object.voiceName) : "",
+      subvoices: isSet(object.subvoices) ? globalThis.String(object.subvoices) : "",
+    };
+  },
+
+  toJSON(message: GetRivaSynthesisConfigResponse_ModelConfig_Parameters): unknown {
+    const obj: any = {};
+    if (message.languageCode !== "") {
+      obj.languageCode = message.languageCode;
+    }
+    if (message.voiceName !== "") {
+      obj.voiceName = message.voiceName;
+    }
+    if (message.subvoices !== "") {
+      obj.subvoices = message.subvoices;
+    }
+    return obj;
+  },
+
+  create<I extends Exact<DeepPartial<GetRivaSynthesisConfigResponse_ModelConfig_Parameters>, I>>(
+    base?: I,
+  ): GetRivaSynthesisConfigResponse_ModelConfig_Parameters {
+    return GetRivaSynthesisConfigResponse_ModelConfig_Parameters.fromPartial(base ?? ({} as any));
+  },
+  fromPartial<I extends Exact<DeepPartial<GetRivaSynthesisConfigResponse_ModelConfig_Parameters>, I>>(
+    object: I,
+  ): GetRivaSynthesisConfigResponse_ModelConfig_Parameters {
+    const message = createBaseGetRivaSynthesisConfigResponse_ModelConfig_Parameters();
+    message.languageCode = object.languageCode ?? "";
+    message.voiceName = object.voiceName ?? "";
+    message.subvoices = object.subvoices ?? "";
+    return message;
+  },
+};
+
+function createBaseAudioConfig(): AudioConfig {
+  return { encoding: 0, sampleRateHz: 0, languageCode: "", enableWordTimeOffsets: false, channels: 0 };
+}
+
+export const AudioConfig = {
+  encode(message: AudioConfig, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
+    if (message.encoding !== 0) {
+      writer.uint32(8).int32(message.encoding);
+    }
+    if (message.sampleRateHz !== 0) {
+      writer.uint32(16).int32(message.sampleRateHz);
+    }
+    if (message.languageCode !== "") {
+      writer.uint32(26).string(message.languageCode);
+    }
+    if (message.enableWordTimeOffsets !== false) {
+      writer.uint32(32).bool(message.enableWordTimeOffsets);
+    }
+    if (message.channels !== 0) {
+      writer.uint32(40).int32(message.channels);
+    }
+    return writer;
+  },
+
+  decode(input: _m0.Reader | Uint8Array, length?: number): AudioConfig {
+    const reader = input instanceof _m0.Reader ? input : _m0.Reader.create(input);
+    let end = length === undefined ? reader.len : reader.pos + length;
+    const message = createBaseAudioConfig();
+    while (reader.pos < end) {
+      const tag = reader.uint32();
+      switch (tag >>> 3) {
+        case 1:
+          if (tag !== 8) {
+            break;
+          }
+
+          message.encoding = reader.int32() as any;
+          continue;
+        case 2:
+          if (tag !== 16) {
+            break;
+          }
+
+          message.sampleRateHz = reader.int32();
+          continue;
+        case 3:
+          if (tag !== 26) {
+            break;
+          }
+
+          message.languageCode = reader.string();
+          continue;
+        case 4:
+          if (tag !== 32) {
+            break;
+          }
+
+          message.enableWordTimeOffsets = reader.bool();
+          continue;
+        case 5:
+          if (tag !== 40) {
+            break;
+          }
+
+          message.channels = reader.int32();
+          continue;
+      }
+      if ((tag & 7) === 4 || tag === 0) {
+        break;
+      }
+      reader.skipType(tag & 7);
+    }
+    return message;
+  },
+
+  fromJSON(object: any): AudioConfig {
+    return {
+      encoding: isSet(object.encoding) ? audioEncodingFromJSON(object.encoding) : 0,
+      sampleRateHz: isSet(object.sampleRateHz) ? globalThis.Number(object.sampleRateHz) : 0,
+      languageCode: isSet(object.languageCode) ? globalThis.String(object.languageCode) : "",
+      enableWordTimeOffsets: isSet(object.enableWordTimeOffsets)
+        ? globalThis.Boolean(object.enableWordTimeOffsets)
+        : false,
+      channels: isSet(object.channels) ? globalThis.Number(object.channels) : 0,
+    };
+  },
+
+  toJSON(message: AudioConfig): unknown {
+    const obj: any = {};
+    if (message.encoding !== 0) {
+      obj.encoding = audioEncodingToJSON(message.encoding);
+    }
+    if (message.sampleRateHz !== 0) {
+      obj.sampleRateHz = Math.round(message.sampleRateHz);
+    }
+    if (message.languageCode !== "") {
+      obj.languageCode = message.languageCode;
+    }
+    if (message.enableWordTimeOffsets !== false) {
+      obj.enableWordTimeOffsets = message.enableWordTimeOffsets;
+    }
+    if (message.channels !== 0) {
+      obj.channels = Math.round(message.channels);
+    }
+    return obj;
+  },
+
+  create<I extends Exact<DeepPartial<AudioConfig>, I>>(base?: I): AudioConfig {
+    return AudioConfig.fromPartial(base ?? ({} as any));
+  },
+  fromPartial<I extends Exact<DeepPartial<AudioConfig>, I>>(object: I): AudioConfig {
+    const message = createBaseAudioConfig();
+    message.encoding = object.encoding ?? 0;
+    message.sampleRateHz = object.sampleRateHz ?? 0;
+    message.languageCode = object.languageCode ?? "";
+    message.enableWordTimeOffsets = object.enableWordTimeOffsets ?? false;
+    message.channels = object.channels ?? 0;
+    return message;
+  },
+};
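+
+// The create()/fromPartial() signatures in the codecs above rely on the
+// Exact<DeepPartial<T>, I> helper defined at the bottom of this file, which
+// rejects object literals carrying keys that are not present on the message type.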
""; + message.enableWordTimeOffsets = object.enableWordTimeOffsets ?? false; + message.channels = object.channels ?? 0; + return message; + }, +}; + +/** Service definitions */ +export interface RivaSpeechSynthesis { + Synthesize(request: SynthesizeRequest): Promise; + SynthesizeOnline(request: SynthesizeRequest): Observable; + GetRivaSynthesisConfig(request: GetRivaSynthesisConfigRequest): Promise; +} + +export const RivaSpeechSynthesisServiceName = "nvidia.riva.RivaSpeechSynthesis"; +export class RivaSpeechSynthesisClientImpl implements RivaSpeechSynthesis { + private readonly rpc: Rpc; + private readonly service: string; + constructor(rpc: Rpc, opts?: { service?: string }) { + this.service = opts?.service || RivaSpeechSynthesisServiceName; + this.rpc = rpc; + this.Synthesize = this.Synthesize.bind(this); + this.SynthesizeOnline = this.SynthesizeOnline.bind(this); + this.GetRivaSynthesisConfig = this.GetRivaSynthesisConfig.bind(this); + } + Synthesize(request: SynthesizeRequest): Promise { + const data = SynthesizeRequest.encode(request).finish(); + const promise = this.rpc.request(this.service, "Synthesize", data); + return promise.then((data) => SynthesizeResponse.decode(_m0.Reader.create(data))); + } + + SynthesizeOnline(request: SynthesizeRequest): Observable { + const data = SynthesizeRequest.encode(request).finish(); + const result = this.rpc.serverStreamingRequest(this.service, "SynthesizeOnline", data); + return result.pipe(map((data) => SynthesizeResponse.decode(_m0.Reader.create(data)))); + } + + GetRivaSynthesisConfig(request: GetRivaSynthesisConfigRequest): Promise { + const data = GetRivaSynthesisConfigRequest.encode(request).finish(); + const promise = this.rpc.request(this.service, "GetRivaSynthesisConfig", data); + return promise.then((data) => GetRivaSynthesisConfigResponse.decode(_m0.Reader.create(data))); + } +} + +interface Rpc { + request(service: string, method: string, data: Uint8Array): Promise; + clientStreamingRequest(service: string, method: string, data: Observable): Promise; + serverStreamingRequest(service: string, method: string, data: Uint8Array): Observable; + bidirectionalStreamingRequest(service: string, method: string, data: Observable): Observable; +} + +function bytesFromBase64(b64: string): Uint8Array { + if ((globalThis as any).Buffer) { + return Uint8Array.from(globalThis.Buffer.from(b64, "base64")); + } else { + const bin = globalThis.atob(b64); + const arr = new Uint8Array(bin.length); + for (let i = 0; i < bin.length; ++i) { + arr[i] = bin.charCodeAt(i); + } + return arr; + } +} + +function base64FromBytes(arr: Uint8Array): string { + if ((globalThis as any).Buffer) { + return globalThis.Buffer.from(arr).toString("base64"); + } else { + const bin: string[] = []; + arr.forEach((byte) => { + bin.push(globalThis.String.fromCharCode(byte)); + }); + return globalThis.btoa(bin.join("")); + } +} + +type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; + +export type DeepPartial = T extends Builtin ? T + : T extends globalThis.Array ? globalThis.Array> + : T extends ReadonlyArray ? ReadonlyArray> + : T extends {} ? { [K in keyof T]?: DeepPartial } + : Partial; + +type KeysOfUnion = T extends T ? keyof T : never; +export type Exact = P extends Builtin ? 
diff --git a/riva-ts-client/src/types/mic.d.ts b/riva-ts-client/src/types/mic.d.ts
new file mode 100644
index 00000000..4afad784
--- /dev/null
+++ b/riva-ts-client/src/types/mic.d.ts
@@ -0,0 +1,24 @@
+declare module 'mic' {
+  interface MicOptions {
+    rate?: string;
+    channels?: string;
+    debug?: boolean;
+    exitOnSilence?: number;
+    device?: string;
+    endian?: 'big' | 'little';
+    bitwidth?: string;
+    encoding?: string;
+    additionalParameters?: string[];
+  }
+
+  interface MicInstance {
+    start(): void;
+    stop(): void;
+    pause(): void;
+    resume(): void;
+    getAudioStream(): NodeJS.ReadableStream;
+  }
+
+  function mic(options?: MicOptions): MicInstance;
+  export = mic;
+}
diff --git a/riva-ts-client/tests/unit/asr.test.ts b/riva-ts-client/tests/unit/asr.test.ts
new file mode 100644
index 00000000..84961fe1
--- /dev/null
+++ b/riva-ts-client/tests/unit/asr.test.ts
@@ -0,0 +1,223 @@
+import { vi, describe, it, expect, beforeEach, type MockInstance } from 'vitest';
+import { ASRService } from '../../src/client/asr';
+import { createGrpcMock, createMetadataMock } from './helpers/grpc';
+import { createAudioMocks } from './helpers/audio';
+import { createMockStream } from './helpers/stream';
+import { AudioEncoding, type ASRServiceClient, type RecognitionConfig, type StreamingRecognitionConfig, type StreamingRecognizeResponse } from '../../src/client/asr/types';
+import * as grpc from '@grpc/grpc-js';
+
+// Create mock client before setting up mocks
+const mockClient = createGrpcMock<ASRServiceClient>(['recognize', 'streamingRecognize', 'listModels']);
+
+vi.mock('@grpc/grpc-js', async () => {
+  const actual = await vi.importActual('@grpc/grpc-js') as typeof grpc;
+  return {
+    ...actual,
+    credentials: {
+      createInsecure: vi.fn(),
+      createFromMetadataGenerator: vi.fn()
+    },
+    Metadata: vi.fn(),
+    Channel: vi.fn().mockImplementation(() => ({
+      getTarget: vi.fn(),
+      close: vi.fn(),
+      getConnectivityState: vi.fn(),
+      watchConnectivityState: vi.fn()
+    }))
+  };
+});
+
+vi.mock('../../src/client/utils/proto', () => ({
+  getProtoClient: () => ({
+    RivaSpeechRecognitionClient: function() {
+      return mockClient;
+    }
+  })
+}));
+
+describe('ASRService', () => {
+  let service: ASRService;
+  let audioMocks;
+  let mockMetadata;
+
+  beforeEach(() => {
+    audioMocks = createAudioMocks();
+    mockMetadata = createMetadataMock();
+
+    // Update the mock metadata implementation after creating mockMetadata
+    (grpc.Metadata as unknown as MockInstance).mockImplementation(() => mockMetadata);
+
+    service = new ASRService({
+      serverUrl: 'test:50051',
+      auth: {
+        credentials: grpc.credentials.createInsecure()
+      }
+    });
+  });
+
+  describe('recognize', () => {
+    it('should recognize audio with correct parameters', async () => {
+      const mockResponse = {
+        results: [{
+          alternatives: [{
+            transcript: 'test transcript',
+            confidence: 0.9,
+            words: [{
+              word: 'test',
+              startTime: 0,
+              endTime: 1,
+              confidence: 0.9,
+              speakerLabel: 'speaker_0'
+            }]
+          }]
+        }]
+      };
+
+      mockClient.recognize.mockResolvedValue(mockResponse);
+
+      const config: RecognitionConfig = {
+        encoding: AudioEncoding.LINEAR_PCM,
+        sampleRateHertz: 16000,
+        languageCode: 'en-US',
+        maxAlternatives: 1,
+        enableAutomaticPunctuation: true
+      };
+
+      const result = await service.recognize(new Uint8Array(100), config);
+
+      expect(mockClient.recognize).toHaveBeenCalledWith({
+        config,
+        audio: {
+          content: expect.any(Uint8Array)
+        }
+      });
+
+      expect(result).toEqual(mockResponse);
+    });
+
+    it('should handle gRPC errors properly', async () => {
+      const error = new Error('Network error');
+      mockClient.recognize.mockRejectedValue(error);
+
+      const config: RecognitionConfig = {
+        encoding: AudioEncoding.LINEAR_PCM,
+        sampleRateHertz: 16000,
+        languageCode: 'en-US',
+        maxAlternatives: 1,
+        enableAutomaticPunctuation: true
+      };
+
+      await expect(service.recognize(new Uint8Array(100), config)).rejects.toThrow('Network error');
+    });
+  });
+
+  describe('streamingRecognize', () => {
+    it('should handle streaming recognition', async () => {
+      const mockStream = createMockStream({
+        onData: () => ({
+          results: [{
+            alternatives: [{
+              transcript: 'test transcript',
+              confidence: 0.9,
+              words: []
+            }]
+          }]
+        })
+      });
+      mockClient.streamingRecognize.mockReturnValue(mockStream);
+
+      const config: StreamingRecognitionConfig = {
+        config: {
+          encoding: AudioEncoding.LINEAR_PCM,
+          sampleRateHertz: 16000,
+          languageCode: 'en-US',
+          maxAlternatives: 1,
+          enableAutomaticPunctuation: true
+        }
+      };
+
+      const audioSource = {
+        content: new Uint8Array(100)
+      };
+
+      const stream = service.streamingRecognize(audioSource, config);
+
+      // Collect all responses
+      const responses: StreamingRecognizeResponse[] = [];
+      for await (const response of stream) {
+        responses.push(response);
+      }
+
+      expect(mockClient.streamingRecognize).toHaveBeenCalled();
+      expect(mockStream.write).toHaveBeenCalledWith({ streamingConfig: config });
+      expect(mockStream.write).toHaveBeenCalledWith({ audioContent: audioSource.content });
+      expect(mockStream.end).toHaveBeenCalled();
+    });
+
+    it('should handle streaming errors', async () => {
+      const mockStream = createMockStream({
+        onError: (error) => {
+          throw error;
+        }
+      });
+
+      mockClient.streamingRecognize.mockReturnValue(mockStream);
+
+      const config: StreamingRecognitionConfig = {
+        config: {
+          encoding: AudioEncoding.LINEAR_PCM,
+          sampleRateHertz: 16000,
+          languageCode: 'en-US',
+          maxAlternatives: 1,
+          enableAutomaticPunctuation: true
+        }
+      };
+
+      const audioSource = {
+        content: new Uint8Array(100)
+      };
+
+      const stream = service.streamingRecognize(audioSource, config);
+
+      await expect(async () => {
+        for await (const _ of stream) {
+          // Just iterate to trigger error
+        }
+      }).rejects.toThrow('Stream error');
+    });
+  });
+
+  describe('listModels', () => {
+    it('should list available models', async () => {
+      const mockResponse = {
+        models: [{
+          name: 'test-model',
+          languages: ['en-US'],
+          sample_rate: 16000,
+          streaming_supported: true
+        }]
+      };
+
+      const expectedResult = [{
+        name: 'test-model',
+        languages: ['en-US'],
+        sampleRate: 16000,
+        streaming: true
+      }];
+
+      mockClient.listModels.mockResolvedValue(mockResponse);
+
+      const result = await service.listModels();
+
+      expect(mockClient.listModels).toHaveBeenCalled();
+      expect(result).toEqual(expectedResult);
+    });
+
+    it('should handle list models error', async () => {
+      const error = new Error('Failed to list models');
+      mockClient.listModels.mockRejectedValue(error);
+
+      await expect(service.listModels()).rejects.toThrow('Failed to list models');
+    });
+  });
+});
diff --git a/riva-ts-client/tests/unit/audio.test.ts b/riva-ts-client/tests/unit/audio.test.ts
new file mode 100644
index 00000000..37baf54f
--- /dev/null
+++ b/riva-ts-client/tests/unit/audio.test.ts
@@ -0,0 +1,349 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { AudioDeviceManagerImpl, MicrophoneStream, SoundCallback } from '../../src/client/audio/io';
+import { AudioDeviceInfo, MicrophoneStreamOptions, SoundCallbackOptions } from '../../src/client/audio/types';
+import { createMockAudioContext, createMockMediaDevices, createMockMediaStream } from './helpers';
+
+describe('AudioDeviceManagerImpl', () => {
+  let deviceManager: AudioDeviceManagerImpl;
+  const mockAudioContext = createMockAudioContext();
+  const mockMediaDevices = createMockMediaDevices();
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+    (global as any).AudioContext = vi.fn().mockImplementation(() => mockAudioContext);
+    (global.navigator as any).mediaDevices = mockMediaDevices;
+    deviceManager = new AudioDeviceManagerImpl();
+  });
+
+  describe('getDeviceInfo', () => {
+    it('should return device info for input device', async () => {
+      const mockDevice = {
+        deviceId: '1',
+        kind: 'audioinput',
+        label: 'Test Microphone'
+      } as MediaDeviceInfo;
+      mockMediaDevices.enumerateDevices.mockResolvedValue([mockDevice]);
+
+      const info = await deviceManager.getDeviceInfo(0);
+
+      expect(info).toEqual({
+        index: 0,
+        name: 'Test Microphone',
+        maxInputChannels: 1,
+        maxOutputChannels: 0,
+        defaultSampleRate: mockAudioContext.sampleRate,
+        defaultLowInputLatency: 0,
+        defaultLowOutputLatency: 0,
+        defaultHighInputLatency: 0,
+        defaultHighOutputLatency: 0
+      });
+    });
+
+    it('should throw error for non-existent device', async () => {
+      mockMediaDevices.enumerateDevices.mockResolvedValue([]);
+
+      await expect(deviceManager.getDeviceInfo(0))
+        .rejects.toThrow('Device with ID 0 not found');
+    });
+  });
+
+  describe('getDefaultInputDeviceInfo', () => {
+    it('should return info for default input device', async () => {
+      const mockDevice = {
+        deviceId: '1',
+        kind: 'audioinput',
+        label: 'Test Microphone'
+      } as MediaDeviceInfo;
+      mockMediaDevices.enumerateDevices.mockResolvedValue([mockDevice]);
+
+      const info = await deviceManager.getDefaultInputDeviceInfo();
+
+      expect(info).toEqual({
+        index: 0,
+        name: 'Test Microphone',
+        maxInputChannels: 1,
+        maxOutputChannels: 0,
+        defaultSampleRate: mockAudioContext.sampleRate,
+        defaultLowInputLatency: 0,
+        defaultLowOutputLatency: 0,
+        defaultHighInputLatency: 0,
+        defaultHighOutputLatency: 0
+      });
+    });
+
+    it('should return null when no input devices found', async () => {
+      mockMediaDevices.enumerateDevices.mockResolvedValue([]);
+      const info = await deviceManager.getDefaultInputDeviceInfo();
+      expect(info).toBeNull();
+    });
+  });
+
+  describe('listInputDevices', () => {
+    it('should list all input devices', async () => {
+      const mockDevices = [
+        { deviceId: '1', kind: 'audioinput', label: 'Mic 1' },
+        { deviceId: '2', kind: 'audioinput', label: 'Mic 2' }
+      ] as MediaDeviceInfo[];
+      mockMediaDevices.enumerateDevices.mockResolvedValue(mockDevices);
+
+      const devices = await deviceManager.listInputDevices();
+
+      expect(devices).toEqual([
+        {
+          index: 0,
+          name: 'Mic 1',
+          maxInputChannels: 1,
+          maxOutputChannels: 0,
+          defaultSampleRate: mockAudioContext.sampleRate,
+          defaultLowInputLatency: 0,
+          defaultLowOutputLatency: 0,
+          defaultHighInputLatency: 0,
+          defaultHighOutputLatency: 0
+        },
+        {
+          index: 1,
+          name: 'Mic 2',
+          maxInputChannels: 1,
+          maxOutputChannels: 0,
+          defaultSampleRate: mockAudioContext.sampleRate,
+          defaultLowInputLatency: 0,
+          defaultLowOutputLatency: 0,
+          defaultHighInputLatency: 0,
+          defaultHighOutputLatency: 0
+        }
+      ]);
+    });
+  });
+
+  describe('listOutputDevices', () => {
+    it('should list all output devices', async () => {
+      const mockDevices = [
+        { deviceId: '1', kind: 'audiooutput', label: 'Speaker 1' },
+        { deviceId: '2', kind: 'audiooutput', label: 'Speaker 2' }
+      ] as MediaDeviceInfo[];
+      mockMediaDevices.enumerateDevices.mockResolvedValue(mockDevices);
+
+      const devices = await deviceManager.listOutputDevices();
+
+      expect(devices).toEqual([
+        {
+          index: 0,
+          name: 'Speaker 1',
+          maxInputChannels: 0,
+          maxOutputChannels: 2,
+          defaultSampleRate: mockAudioContext.sampleRate,
+          defaultLowInputLatency: 0,
+          defaultLowOutputLatency: 0,
+          defaultHighInputLatency: 0,
+          defaultHighOutputLatency: 0
+        },
+        {
+          index: 1,
+          name: 'Speaker 2',
+          maxInputChannels: 0,
+          maxOutputChannels: 2,
+          defaultSampleRate: mockAudioContext.sampleRate,
+          defaultLowInputLatency: 0,
+          defaultLowOutputLatency: 0,
+          defaultHighInputLatency: 0,
+          defaultHighOutputLatency: 0
+        }
+      ]);
+    });
+  });
+});
+
+describe('MicrophoneStream', () => {
+  let micStream: MicrophoneStream;
+  const mockAudioContext = createMockAudioContext();
+  const mockMediaDevices = createMockMediaDevices();
+  const mockTrack = {
+    stop: vi.fn(),
+    enabled: true
+  };
+  const mockMediaStream = {
+    getTracks: vi.fn().mockReturnValue([mockTrack])
+  };
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+    (global as any).AudioContext = vi.fn().mockImplementation(() => mockAudioContext);
+    (global.navigator as any).mediaDevices = mockMediaDevices;
+    mockMediaDevices.getUserMedia.mockResolvedValue(mockMediaStream);
+
+    const options: MicrophoneStreamOptions = {
+      rate: 16000,
+      chunk: 1024,
+      device: 1
+    };
+    micStream = new MicrophoneStream(options);
+  });
+
+  describe('start', () => {
+    it('should set up audio processing chain', async () => {
+      const mockSourceNode = {
+        connect: vi.fn(),
+        disconnect: vi.fn()
+      };
+      const mockProcessorNode = {
+        connect: vi.fn(),
+        disconnect: vi.fn(),
+        onaudioprocess: null as any
+      };
+
+      mockAudioContext.createMediaStreamSource.mockReturnValue(mockSourceNode);
+      mockAudioContext.createScriptProcessor.mockReturnValue(mockProcessorNode);
+
+      await micStream.start();
+
+      expect(mockMediaDevices.getUserMedia).toHaveBeenCalled();
+      expect(mockAudioContext.createMediaStreamSource).toHaveBeenCalledWith(mockMediaStream);
+      expect(mockSourceNode.connect).toHaveBeenCalled();
+      expect(mockProcessorNode.connect).toHaveBeenCalledWith(mockAudioContext.destination);
+    });
+
+    it('should emit error on getUserMedia failure', async () => {
+      const error = new Error('Permission denied');
+      mockMediaDevices.getUserMedia.mockRejectedValue(error);
+
+      await expect(micStream.start()).rejects.toThrow('Permission denied');
+    });
+
+    it('should not start if already active', async () => {
+      const mockSourceNode = {
+        connect: vi.fn(),
+        disconnect: vi.fn()
+      };
+      mockAudioContext.createMediaStreamSource.mockReturnValue(mockSourceNode);
+
+      await micStream.start();
+      await micStream.start();
+
+      expect(mockMediaDevices.getUserMedia).toHaveBeenCalledTimes(1);
+    });
+  });
+
+  describe('stop', () => {
+    it('should clean up resources', async () => {
+      const mockSourceNode = {
+        connect: vi.fn(),
+        disconnect: vi.fn()
+      };
+      const mockProcessorNode = {
+        connect: vi.fn(),
+        disconnect: vi.fn(),
+        onaudioprocess: null as any
+      };
+
+      mockAudioContext.createMediaStreamSource.mockReturnValue(mockSourceNode);
+      mockAudioContext.createScriptProcessor.mockReturnValue(mockProcessorNode);
+
+      await micStream.start();
+      micStream.stop();
+
+      expect(mockTrack.stop).toHaveBeenCalled();
+      expect(mockAudioContext.close).toHaveBeenCalled();
+      expect(mockSourceNode.disconnect).toHaveBeenCalled();
+      expect(mockProcessorNode.disconnect).toHaveBeenCalled();
+    });
+
+    it('should do nothing if not active', () => {
+      micStream.stop();
+      expect(mockAudioContext.close).not.toHaveBeenCalled();
+    });
+  });
+
+  describe('pause/resume', () => {
+    it('should toggle track enabled state', async () => {
+      await micStream.start();
+
+      micStream.pause();
+      expect(mockTrack.enabled).toBe(false);
+
+      micStream.resume();
+      expect(mockTrack.enabled).toBe(true);
+    });
+  });
+
+  describe('isActive', () => {
+    it('should return correct active state', async () => {
+      expect(micStream.isActive()).toBe(false);
+
+      await micStream.start();
+      expect(micStream.isActive()).toBe(true);
+
+      micStream.stop();
+      expect(micStream.isActive()).toBe(false);
+    });
+  });
+});
+
+describe('SoundCallback', () => {
+  let soundCallback: SoundCallback;
+  const mockAudioContext = createMockAudioContext();
+  const options: SoundCallbackOptions = {
+    sampwidth: 2,
+    nchannels: 1,
+    framerate: 44100
+  };
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+    (global as any).AudioContext = vi.fn().mockImplementation(() => mockAudioContext);
+    soundCallback = new SoundCallback(options);
+  });
+
+  describe('write', () => {
+    it('should process audio data correctly', async () => {
+      const mockBuffer = Buffer.from([1, 2, 3, 4]);
+      const mockAudioBuffer = { duration: 1 };
+      const mockSource = {
+        buffer: null as AudioBuffer | null,
+        connect: vi.fn(),
+        start: vi.fn()
+      };
+
+      mockAudioContext.decodeAudioData.mockImplementation((_buffer, onSuccess) => {
+        if (onSuccess) {
+          onSuccess(mockAudioBuffer as AudioBuffer);
+        }
+        return Promise.resolve(mockAudioBuffer as AudioBuffer);
+      });
+      mockAudioContext.createBufferSource.mockReturnValue(mockSource as unknown as AudioBufferSourceNode);
+
+      await soundCallback.write(mockBuffer);
+
+      expect(mockAudioContext.decodeAudioData).toHaveBeenCalled();
+      expect(mockSource.buffer).toBe(mockAudioBuffer);
+      expect(mockSource.connect).toHaveBeenCalledWith(mockAudioContext.destination);
+      expect(mockSource.start).toHaveBeenCalled();
+    });
+
+    it('should throw error when closed', async () => {
+      await soundCallback.close();
+      await expect(soundCallback.write(Buffer.from([1, 2, 3, 4])))
+        .rejects.toThrow('Sound callback is closed');
+    });
+
+    it('should handle decodeAudioData failure', async () => {
+      const error = new Error('Failed to decode audio data');
+      mockAudioContext.decodeAudioData.mockRejectedValue(error);
+
+      await expect(soundCallback.write(Buffer.from([1, 2, 3, 4])))
+        .rejects.toThrow('Failed to decode audio data');
+    });
+  });
+
+  describe('close', () => {
+    it('should close audio context', async () => {
+      await soundCallback.close();
+      expect(mockAudioContext.close).toHaveBeenCalled();
+    });
+
+    it('should only close once', async () => {
+      await soundCallback.close();
+      await soundCallback.close();
+      expect(mockAudioContext.close).toHaveBeenCalledTimes(1);
+    });
+  });
+});
diff --git a/riva-ts-client/tests/unit/auth.test.ts b/riva-ts-client/tests/unit/auth.test.ts
new file mode 100644
index 00000000..5a2044dc
--- /dev/null
+++ b/riva-ts-client/tests/unit/auth.test.ts
@@ -0,0 +1,310 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { Auth, AuthOptions } from '../../src/client/auth';
+import * as grpc from '@grpc/grpc-js';
+import * as fs from 'fs';
+import { resolve } from 'path';
+
+vi.mock('fs', () => ({
+  readFileSync: vi.fn()
+}));
+
+vi.mock('@grpc/grpc-js', async () => {
+  let metadataStore = new Map<string, string[]>();
+
+  class MetadataMock {
+    constructor() {
+      metadataStore = new Map<string, string[]>();
+    }
+
+    add(key: string, value: string) {
+      const values = metadataStore.get(key) || [];
+      values.push(value);
+      metadataStore.set(key, values);
+    }
+
+    get(key: string) {
+      return metadataStore.get(key) || [];
+    }
+
+    getMap() {
+      const map: Record<string, string> = {};
+      metadataStore.forEach((values, key) => {
+        map[key] = values[0];
+      });
+      return map;
+    }
+
+    set(key: string, value: string) {
+      metadataStore.set(key, [value]);
+    }
+  }
+
+  return {
+    ...(await vi.importActual('@grpc/grpc-js')),
+    Channel: vi.fn().mockImplementation(() => ({
+      getTarget: vi.fn().mockReturnValue('localhost:50051')
+    })),
+    credentials: {
+      createInsecure: vi.fn().mockReturnValue({}),
+      createSsl: vi.fn().mockReturnValue({}),
+      createFromMetadataGenerator: vi.fn().mockReturnValue({}),
+      combineChannelCredentials: vi.fn().mockReturnValue({})
+    },
+    Metadata: MetadataMock
+  };
+});
+
+describe('Auth', () => {
+  const testUri = 'localhost:50051';
+  let auth: Auth;
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+  });
+
+  describe('constructor with AuthOptions', () => {
+    it('should initialize with default values', () => {
+      auth = new Auth({ uri: testUri });
+      expect(auth['uri']).toBe(testUri);
+      expect(auth['useSsl']).toBe(false);
+      expect(auth['sslCert']).toBeUndefined();
+      expect(auth['metadata']).toEqual([]);
+      expect(auth['channelOptions']).toEqual({});
+      expect(grpc.Channel).toHaveBeenCalledWith(
+        testUri,
+        expect.any(Object),
+        {}
+      );
+    });
+
+    it('should initialize with custom values', () => {
+      const sslCert = 'test-cert.pem';
+      const metadata: [string, string][] = [['key', 'value']];
+      const channelOptions = { 'grpc.keepalive_time_ms': 10000 };
+
+      auth = new Auth({
+        uri: testUri,
+        useSsl: true,
+        sslCert,
+        metadata,
+        channelOptions
+      });
+
+      expect(auth['uri']).toBe(testUri);
+      expect(auth['useSsl']).toBe(true);
+      expect(auth['sslCert']).toBe(sslCert);
+      expect(auth['metadata']).toEqual(metadata);
+      expect(auth['channelOptions']).toEqual(channelOptions);
+      expect(fs.readFileSync).toHaveBeenCalledWith(resolve(sslCert));
+    });
+
+    it('should add api-key to metadata when provided', () => {
+      const apiKey = 'test-api-key';
+      auth = new Auth({
+        uri: testUri,
+        apiKey
+      });
+
+      expect(auth['metadata']).toEqual([['api-key', apiKey]]);
+    });
+
+    it('should combine provided metadata with api-key', () => {
+      const apiKey = 'test-api-key';
+      const metadata: [string, string][] = [['custom-key', 'custom-value']];
+      auth = new Auth({
+        uri: testUri,
+        apiKey,
+        metadata
+      });
+
+      expect(auth['metadata']).toEqual([
+        ['custom-key', 'custom-value'],
+        ['api-key', apiKey]
+      ]);
+    });
+
+    it('should use provided credentials if available', () => {
+      const customCredentials = grpc.credentials.createInsecure();
+      auth = new Auth({
+        uri: testUri,
+        credentials: customCredentials
+      });
+
+      expect(grpc.Channel).toHaveBeenCalledWith(
+        testUri,
+        customCredentials,
+        {}
+      );
+    });
+  });
+
+  describe('constructor with Python-style arguments', () => {
+    it('should initialize with default values', () => {
+      auth = new Auth();
+      expect(auth['uri']).toBe('localhost:50051');
+      expect(auth['useSsl']).toBe(false);
+      expect(auth['sslCert']).toBeUndefined();
+      expect(auth['metadata']).toEqual([]);
+      expect(auth['channelOptions']).toEqual({});
+    });
+
+    it('should initialize with custom values', () => {
+      const sslCert = 'test-cert.pem';
+      const metadataArgs = [['key', 'value']];
+
+      auth = new Auth(sslCert, true, testUri, metadataArgs);
+
+      expect(auth['uri']).toBe(testUri);
+      expect(auth['useSsl']).toBe(true);
+      expect(auth['sslCert']).toBe(sslCert);
+      expect(auth['metadata']).toEqual(metadataArgs);
+      expect(auth['channelOptions']).toEqual({});
+      expect(fs.readFileSync).toHaveBeenCalledWith(resolve(sslCert));
+    });
+
+    it('should throw error for invalid metadata format', () => {
+      expect(() => {
+        new Auth(undefined, false, testUri, [['single']]);
+      }).toThrow('Metadata should have 2 parameters in "key" "value" pair. Received 1 parameters.');
+    });
+  });
+
+  describe('getCallMetadata', () => {
+    it('should return empty metadata when none provided', () => {
+      auth = new Auth({ uri: testUri });
+      const metadata = auth.getCallMetadata();
+      expect(metadata instanceof grpc.Metadata).toBe(true);
+      expect(metadata.getMap()).toEqual({});
+    });
+
+    it('should return metadata with provided values', () => {
+      auth = new Auth({
+        uri: testUri,
+        metadata: [['key1', 'value1'], ['key2', 'value2']]
+      });
+      const metadata = auth.getCallMetadata();
+      expect(metadata instanceof grpc.Metadata).toBe(true);
+      expect(metadata.get('key1')).toEqual(['value1']);
+      expect(metadata.get('key2')).toEqual(['value2']);
+    });
+
+    it('should add multiple values for the same key', () => {
+      auth = new Auth({
+        uri: testUri,
+        metadata: [['key1', 'value1'], ['key1', 'value2']]
+      });
+      const metadata = auth.getCallMetadata();
+      expect(metadata instanceof grpc.Metadata).toBe(true);
+      const key1Values = metadata.get('key1');
+      expect(key1Values).toHaveLength(2);
+      expect(key1Values).toContain('value1');
+      expect(key1Values).toContain('value2');
+    });
+  });
+
+  describe('getAuthMetadata', () => {
+    it('should return empty array when no metadata', () => {
+      auth = new Auth({ uri: testUri });
+      expect(auth.getAuthMetadata()).toEqual([]);
+    });
+
+    it('should return metadata array with provided values', () => {
+      const metadata: [string, string][] = [['key1', 'value1'], ['key2', 'value2']];
+      auth = new Auth({
+        uri: testUri,
+        metadata
+      });
+      expect(auth.getAuthMetadata()).toEqual(metadata);
+    });
+
+    it('should return metadata including api-key', () => {
+      const metadata: [string, string][] = [['key1', 'value1']];
+      const apiKey = 'test-api-key';
+      auth = new Auth({
+        uri: testUri,
+        metadata,
+        apiKey
+      });
+      expect(auth.getAuthMetadata()).toEqual([
+        ['key1', 'value1'],
+        ['api-key', apiKey]
+      ]);
+    });
+  });
+
+  describe('channel creation', () => {
+    it('should create insecure channel by default', () => {
+      auth = new Auth({ uri: testUri });
+      expect(grpc.credentials.createInsecure).toHaveBeenCalled();
+      expect(grpc.credentials.createSsl).not.toHaveBeenCalled();
+    });
+
+    it('should create SSL channel when useSsl is true', () => {
+      auth = new Auth({
+        uri: testUri,
+        useSsl: true
+      });
+      expect(grpc.credentials.createSsl).toHaveBeenCalledWith(null);
+      expect(grpc.credentials.createInsecure).not.toHaveBeenCalled();
+    });
+
+    it('should create SSL channel with cert when provided', () => {
+      const sslCert = 'test-cert.pem';
+      const certBuffer = Buffer.from('test-cert-content');
+      vi.mocked(fs.readFileSync).mockReturnValue(certBuffer);
+
+      auth = new Auth({
+        uri: testUri,
+        sslCert
+      });
+      expect(grpc.credentials.createSsl).toHaveBeenCalledWith(certBuffer);
+      expect(fs.readFileSync).toHaveBeenCalledWith(resolve(sslCert));
+    });
+
+    it('should combine channel credentials with metadata when using SSL', () => {
+      const metadata: [string, string][] = [['key', 'value']];
+      auth = new Auth({
+        uri: testUri,
+        useSsl: true,
+        metadata
+      });
+      expect(grpc.credentials.createFromMetadataGenerator).toHaveBeenCalled();
+      expect(grpc.credentials.combineChannelCredentials).toHaveBeenCalled();
+    });
+
+    it('should use provided channel options', () => {
+      const channelOptions = { 'grpc.keepalive_time_ms': 10000 };
+      auth = new Auth({
+        uri: testUri,
+        channelOptions
+      });
+      expect(grpc.Channel).toHaveBeenCalledWith(
+        testUri,
+        expect.any(Object),
+        channelOptions
+      );
+    });
+
+    it('should handle file read errors gracefully', () => {
+      const sslCert = 'nonexistent.pem';
+      vi.mocked(fs.readFileSync).mockImplementation(() => {
+        throw new Error('File not found');
+      });
+
+      expect(() => {
+        new Auth({
+          uri: testUri,
+          sslCert
+        });
+      }).toThrow();
+    });
+  });
+
+  describe('deprecated methods', () => {
+    it('createChannel should return the same channel instance', () => {
+      auth = new Auth({ uri: testUri });
+      const channel = auth.createChannel();
+      expect(channel).toBe(auth.channel);
+    });
+  });
+});
diff --git a/riva-ts-client/tests/unit/helpers.ts b/riva-ts-client/tests/unit/helpers.ts
new file mode 100644
index 00000000..f26cdfea
--- /dev/null
+++ b/riva-ts-client/tests/unit/helpers.ts
@@ -0,0 +1,217 @@
+import { vi, type Mock } from 'vitest';
+import { Auth } from '../../src/client/auth';
+import * as grpc from '@grpc/grpc-js';
+import { ClientDuplexStream, Metadata, StatusObject } from '@grpc/grpc-js';
+import { RivaConfig } from '../../src/client/types';
+
+type MockAuthConfig = {
+  serverUrl: string;
+  auth: {
+    ssl: boolean;
+    sslCert?: string;
+    metadata?: Array<[string, string]>;
+  };
+};
+
+/**
+ * Creates a mock Auth instance for testing
+ * @returns Tuple of [mockConfig, mockMetadata]
+ */
+export function createAuthMock(): [MockAuthConfig, grpc.Metadata] {
+  const mockMetadata = new grpc.Metadata();
+  mockMetadata.set('test-key', 'test-value');
+
+  const mockConfig: MockAuthConfig = {
+    serverUrl: 'localhost:50051',
+    auth: {
+      ssl: false,
+      metadata: [['test-key', 'test-value']]
+    }
+  };
+
+  return [mockConfig, mockMetadata];
+}
+
+type GrpcEventType = 'data' | 'end' | 'error' | 'status' | 'metadata' | 'close';
+
+/**
+ * Creates a mock gRPC stream for testing
+ */
+export function createMockStream<TReq, TRes>(): ClientDuplexStream<TReq, TRes> {
+  const eventHandlers: Map<string, Function[]> = new Map();
+
+  const mockStream = {
+    on: vi.fn().mockImplementation((event: GrpcEventType, handler: Function) => {
+      if (!eventHandlers.has(event)) {
+        eventHandlers.set(event, []);
+      }
+      eventHandlers.get(event)!.push(handler);
+      return mockStream;
+    }),
+
+    write: vi.fn().mockImplementation((data: TReq) => {
+      const handlers = eventHandlers.get('data') || [];
+      handlers.forEach(handler => handler(data));
+      return true;
+    }),
+
+    end: vi.fn().mockImplementation(() => {
+      const handlers = eventHandlers.get('end') || [];
+      handlers.forEach(handler => handler());
+    }),
+
+    destroy: vi.fn().mockImplementation((error?: Error) => {
+      if (error) {
+        const handlers = eventHandlers.get('error') || [];
+        handlers.forEach(handler => handler(error));
+      }
+      const closeHandlers = eventHandlers.get('close') || [];
+      closeHandlers.forEach(handler => handler());
+    }),
+
+    emit: vi.fn().mockImplementation((event: string, ...args: any[]) => {
+      const handlers = eventHandlers.get(event) || [];
+      handlers.forEach(handler => handler(...args));
+      return true;
+    }),
+
+    removeListener: vi.fn(),
+    removeAllListeners: vi.fn(),
+    pause: vi.fn(),
+    resume: vi.fn(),
+    isPaused: vi.fn().mockReturnValue(false),
+    pipe: vi.fn(),
+    unpipe: vi.fn(),
+    unshift: vi.fn(),
+    wrap: vi.fn(),
+    [Symbol.asyncIterator]: vi.fn().mockImplementation(function* () {
+      const dataHandlers = eventHandlers.get('data') || [];
+      for (const handler of dataHandlers) {
+        yield handler;
+      }
+    })
+  } as unknown as ClientDuplexStream<TReq, TRes>;
+
+  return mockStream;
+}
+
+export class MockAudioContext {
+  createMediaStreamSource: Mock;
+  createScriptProcessor: Mock;
+  decodeAudioData: Mock;
+  createBufferSource: Mock;
+  destination: {};
+  close: Mock;
+  sampleRate: number;
+
+  constructor() {
+    this.createMediaStreamSource = vi.fn();
+    this.createScriptProcessor = vi.fn();
+    this.decodeAudioData = vi.fn();
+    this.createBufferSource = vi.fn();
+    this.destination = {};
+    this.close = vi.fn();
+    this.sampleRate = 44100;
+  }
+}
+
+/**
+ * Creates a mock AudioContext for testing
+ */
+export function createMockAudioContext(): MockAudioContext {
+  return new MockAudioContext();
+}
+
+export class MockMediaDevices {
+  getUserMedia: Mock;
+  enumerateDevices: Mock;
+
+  constructor() {
+    this.getUserMedia = vi.fn();
+    this.enumerateDevices = vi.fn();
+  }
+}
+
+/**
+ * Creates a mock MediaDevices for testing
+ */
+export function createMockMediaDevices(): MockMediaDevices {
+  return new MockMediaDevices();
+}
+
+export class MockMediaStream {
+  getTracks: Mock;
+
+  constructor() {
+    this.getTracks = vi.fn().mockReturnValue([]);
+  }
+}
+
+/**
+ * Creates a mock MediaStream for testing
+ */
+export function createMockMediaStream(): MockMediaStream {
+  return new MockMediaStream();
+}
+
+export class MockGrpcClient {
+  recognize: Mock;
+  streamingRecognize: Mock;
+  synthesize: Mock;
+  streamingSynthesize: Mock;
+  classify: Mock;
+  tokenClassify: Mock;
+  analyzeEntities: Mock;
+  analyzeIntent: Mock;
+  transformText: Mock;
+  naturalQuery: Mock;
+  translateText: Mock;
+  streamingTranslateSpeechToSpeech: Mock;
+  streamingTranslateSpeechToText: Mock;
+  listSupportedLanguagePairs: Mock;
+
+  constructor() {
+    this.recognize = vi.fn();
+    this.streamingRecognize = vi.fn();
+    this.synthesize = vi.fn();
+    this.streamingSynthesize = vi.fn();
+    this.classify = vi.fn();
+    this.tokenClassify = vi.fn();
+    this.analyzeEntities = vi.fn();
+    this.analyzeIntent = vi.fn();
+    this.transformText = vi.fn();
+    this.naturalQuery = vi.fn();
+    this.translateText = vi.fn();
+    this.streamingTranslateSpeechToSpeech = vi.fn();
+    this.streamingTranslateSpeechToText = vi.fn();
+    this.listSupportedLanguagePairs = vi.fn();
+  }
+}
+
+/**
+ * Creates a mock gRPC client for testing
+ */
+export function createMockGrpcClient(): MockGrpcClient {
+  return new MockGrpcClient();
+}
+
+export class MockBuffer {
+  length: number;
+  slice: Mock;
+  toString: Mock;
+  readInt16LE: Mock;
+
+  constructor(data: number[] = []) {
+    this.length = data.length;
+    this.slice = vi.fn().mockReturnThis();
+    this.toString = vi.fn().mockReturnValue('');
+    this.readInt16LE = vi.fn().mockReturnValue(0);
+  }
+}
+
+/**
+ * Creates a mock Buffer for testing
+ */
+export function createMockBuffer(data: number[] = []): MockBuffer {
+  return new MockBuffer(data);
+}
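
For context, the duplex-stream helper above is meant to be consumed roughly as follows (an illustrative sketch only; the suites in this patch use the simpler variant from `helpers/stream.ts`):

```typescript
import { expect, it } from 'vitest';
import { createMockStream } from './helpers';

it('delivers written chunks to registered data handlers', () => {
  const stream = createMockStream<{ audioContent: Uint8Array }, unknown>();
  const received: Array<{ audioContent: Uint8Array }> = [];

  stream.on('data', (chunk: { audioContent: Uint8Array }) => received.push(chunk));
  stream.write({ audioContent: new Uint8Array([1, 2, 3]) }); // fans out to 'data' handlers
  stream.end(); // fans out to 'end' handlers

  expect(received).toHaveLength(1);
});
```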
diff --git a/riva-ts-client/tests/unit/helpers/audio.ts b/riva-ts-client/tests/unit/helpers/audio.ts
new file mode 100644
index 00000000..fd3ed3b5
--- /dev/null
+++ b/riva-ts-client/tests/unit/helpers/audio.ts
@@ -0,0 +1,66 @@
+import { vi } from 'vitest';
+
+/**
+ * Create mocks for Web Audio API components
+ */
+export const createAudioMocks = () => {
+  const mockTrack = {
+    stop: vi.fn(),
+    enabled: true,
+    kind: 'audio',
+    label: 'mock-track',
+    id: 'mock-track-id',
+    muted: false,
+    readyState: 'live',
+    applyConstraints: vi.fn(),
+    clone: vi.fn(),
+    getCapabilities: vi.fn(),
+    getConstraints: vi.fn(),
+    getSettings: vi.fn()
+  };
+
+  const mockMediaStream = {
+    active: true,
+    id: 'mock-stream-id',
+    getTracks: () => [mockTrack],
+    getAudioTracks: () => [mockTrack],
+    addTrack: vi.fn(),
+    clone: vi.fn(),
+    getTrackById: vi.fn(),
+    removeTrack: vi.fn()
+  };
+
+  const mockAudioContext = {
+    state: 'running',
+    sampleRate: 44100,
+    destination: {},
+    listener: {},
+    currentTime: 0,
+    decodeAudioData: vi.fn(),
+    createBuffer: vi.fn(),
+    createBufferSource: vi.fn(),
+    createMediaStreamSource: vi.fn(),
+    createAnalyser: vi.fn(),
+    createBiquadFilter: vi.fn(),
+    createGain: vi.fn(),
+    createOscillator: vi.fn(),
+    createPanner: vi.fn(),
+    createDynamicsCompressor: vi.fn(),
+    close: vi.fn(),
+    suspend: vi.fn(),
+    resume: vi.fn()
+  };
+
+  return {
+    mockTrack,
+    mockMediaStream,
+    mockAudioContext,
+    mockAudioBuffer: {
+      duration: 1.0,
+      length: 44100,
+      numberOfChannels: 1,
+      sampleRate: 44100,
+      getChannelData: vi.fn().mockReturnValue(new Float32Array(44100))
+    }
+  };
+};
diff --git a/riva-ts-client/tests/unit/helpers/grpc.ts b/riva-ts-client/tests/unit/helpers/grpc.ts
new file mode 100644
index 00000000..5f97922e
--- /dev/null
+++ b/riva-ts-client/tests/unit/helpers/grpc.ts
@@ -0,0 +1,24 @@
+import { vi } from 'vitest';
+
+/**
+ * Create a typed gRPC mock with vi.fn() for each method
+ */
+export const createGrpcMock = <T extends Record<string, any>>(methods: Array<keyof T>) => {
+  const mock: Partial<{ [K in keyof T]: ReturnType<typeof vi.fn> }> = {};
+  methods.forEach(method => {
+    mock[method] = vi.fn();
+  });
+  return mock as { [K in keyof T]: ReturnType<typeof vi.fn> };
+};
+
+/**
+ * Create a mock gRPC metadata instance
+ */
+export const createMetadataMock = () => ({
+  get: vi.fn(),
+  set: vi.fn(),
+  getMap: vi.fn(),
+  clone: vi.fn(),
+  merge: vi.fn(),
+  toHttp2Headers: vi.fn()
+});
diff --git a/riva-ts-client/tests/unit/helpers/stream.ts b/riva-ts-client/tests/unit/helpers/stream.ts
new file mode 100644
index 00000000..011920af
--- /dev/null
+++ b/riva-ts-client/tests/unit/helpers/stream.ts
@@ -0,0 +1,55 @@
+import { vi, type Mock } from 'vitest';
+
+export type MockStreamEvents = 'data' | 'error' | 'end' | 'close';
+
+export interface MockStream {
+  write: Mock;
+  end: Mock;
+  on: Mock<[event: MockStreamEvents, callback: (...args: any[]) => void], any>;
+  removeListener: Mock;
+  [Symbol.asyncIterator](): AsyncIterator<any>;
+}
+
+export interface MockStreamOptions {
+  onError?: (error: Error) => void;
+  onData?: (data: any) => void;
+  onEnd?: () => void;
+}
+
+/**
+ * Creates a mock stream with vitest mock functions
+ */
+export const createMockStream = (options?: MockStreamOptions): MockStream => {
+  const write = vi.fn();
+  const end = vi.fn();
+  const removeListener = vi.fn();
+  const on = vi.fn((event: MockStreamEvents, callback: (...args: any[]) => void) => {
+    if (event === 'error' && options?.onError) {
+      callback(new Error('Stream error'));
+    }
+    if (event === 'data' && options?.onData) {
+      callback({});
+    }
+    if (event === 'end' && options?.onEnd) {
+      callback();
+    }
+  });
+
+  return {
+    write,
+    end,
+    on,
+    removeListener,
+    async *[Symbol.asyncIterator]() {
+      if (options?.onError) {
+        throw new Error('Stream error');
+      }
+      if (options?.onData) {
+        yield {};
+      }
+      if (options?.onEnd) {
+        return;
+      }
+    }
+  };
+};
diff --git a/riva-ts-client/tests/unit/nlp.test.ts b/riva-ts-client/tests/unit/nlp.test.ts
new file mode 100644
index 00000000..508df01b
--- /dev/null
+++ b/riva-ts-client/tests/unit/nlp.test.ts
@@ -0,0 +1,292 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { NLPService } from '../../src/client/nlp';
+import { RivaError } from '../../src/client/errors';
+import { createGrpcMock } from './helpers/grpc';
+import type {
+  ClassifyResponse,
+  TokenClassifyResponse,
+  TransformTextResponse,
+  AnalyzeEntitiesResponse,
+  AnalyzeIntentResponse,
+  NaturalQueryResponse,
+  AnalyzeIntentResponse_Slot
+} from '../../src/proto/riva_nlp';
+import { status } from '@grpc/grpc-js';
+import type { ServiceError } from '@grpc/grpc-js';
+import { Metadata } from '@grpc/grpc-js';
+
+const mockClient = createGrpcMock([
+  'Classify',
+  'TokenClassify',
+  'TransformText',
+  'AnalyzeEntities',
+  'AnalyzeIntent',
+  'NaturalQuery'
+]);
+
+// Mock dependencies
+vi.mock('@grpc/grpc-js', async () => {
+  const actual = await vi.importActual('@grpc/grpc-js');
+  return {
+    ...actual,
+    credentials: {
+      createInsecure: vi.fn(),
+      createFromMetadataGenerator: vi.fn()
+    },
+    Metadata: vi.fn(),
+    Channel: vi.fn().mockImplementation(() => ({
+      getTarget: vi.fn(),
+      close: vi.fn(),
+      getConnectivityState: vi.fn(),
+      watchConnectivityState: vi.fn()
+    }))
+  };
+});
+
+// Mock getProtoClient
+vi.mock('../../src/client/utils/proto', () => ({
+  getProtoClient: vi.fn().mockReturnValue({
+    RivaNLPServiceClient: vi.fn().mockImplementation(() => mockClient)
+  })
+}));
+
+describe('NLPService', () => {
+  let service: NLPService;
+  const mockConfig = {
+    serverUrl: 'localhost:50051',
+    auth: {
+      credentials: {}
+    }
+  };
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+    service = new NLPService(mockConfig);
+  });
+
+  describe('classifyText', () => {
+    it('should classify text with correct parameters', async () => {
+      const mockResponse: ClassifyResponse = {
+        results: [{
+          label: 'test',
+          score: 0.9
+        }]
+      };
+
+      mockClient.Classify.mockResolvedValue(mockResponse);
+
+      const result = await service.classifyText('test text', 'test-model');
+      expect(result).toEqual(mockResponse);
+      expect(mockClient.Classify).toHaveBeenCalledWith({
+        text: ['test text'],
+        model: { modelName: 'test-model', languageCode: 'en-US' }
+      });
+    });
+
+    it('should handle gRPC errors properly', async () => {
+      const mockError = new Error('UNAVAILABLE: Server is currently unavailable') as ServiceError;
+      mockError.code = status.UNAVAILABLE;
+      mockError.details = 'Server is down for maintenance';
+      mockError.metadata = new Metadata();
+
+      mockClient.Classify.mockRejectedValue(mockError);
+
+      await expect(service.classifyText('test text', 'test-model')).rejects.toThrow(RivaError);
+      await expect(service.classifyText('test text', 'test-model')).rejects.toThrow('UNAVAILABLE: Server is currently unavailable');
+    });
+  });
+
+  describe('classifyTokens', () => {
+    it('should classify tokens with correct parameters', async () => {
+      const mockResponse: TokenClassifyResponse = {
+        results: [{
+          tokens: [{
+            text: 'token1',
+            label: 'label1',
+            score: 0.9,
+            start: 0,
+            end: 6
+          }]
+        }]
+      };
+
+      mockClient.TokenClassify.mockResolvedValue(mockResponse);
+
+      const result = await service.classifyTokens('test text', 'test-model');
+      expect(result).toEqual(mockResponse);
+      expect(mockClient.TokenClassify).toHaveBeenCalledWith({
+        text: ['test text'],
+        model: { modelName: 'test-model', languageCode: 'en-US' }
+      });
+    });
+
+    it('should handle gRPC errors properly', async () => {
+      const mockError = new Error('UNAVAILABLE: Server is currently unavailable') as ServiceError;
+      mockError.code = status.UNAVAILABLE;
+      mockError.details = 'Server is down for maintenance';
+      mockError.metadata = new Metadata();
+
+      mockClient.TokenClassify.mockRejectedValue(mockError);
+
+      await expect(service.classifyTokens('test text', 'test-model')).rejects.toThrow(RivaError);
+      await expect(service.classifyTokens('test text', 'test-model')).rejects.toThrow('UNAVAILABLE: Server is currently unavailable');
+    });
+  });
+
+  describe('transformText', () => {
+    it('should transform text with correct parameters', async () => {
+      const mockResponse: TransformTextResponse = {
+        text: 'transformed text'
+      };
+
+      mockClient.TransformText.mockResolvedValue(mockResponse);
+
+      const result = await service.transformText('test text', 'test-model');
+      expect(result).toEqual(mockResponse);
+      expect(mockClient.TransformText).toHaveBeenCalledWith({
+        text: 'test text',
+        model: 'test-model'
+      });
+    });
+
+    it('should handle gRPC errors properly', async () => {
+      const mockError = new Error('UNAVAILABLE: Server is currently unavailable') as ServiceError;
+      mockError.code = status.UNAVAILABLE;
+      mockError.details = 'Server is down for maintenance';
+      mockError.metadata = new Metadata();
+
+      mockClient.TransformText.mockRejectedValue(mockError);
+
+      await expect(service.transformText('test text', 'test-model')).rejects.toThrow(RivaError);
+      await expect(service.transformText('test text', 'test-model')).rejects.toThrow('UNAVAILABLE: Server is currently unavailable');
+    });
+  });
+
+  describe('punctuateText', () => {
+    it('should punctuate text with correct parameters', async () => {
+      const mockResponse: TransformTextResponse = {
+        text: 'punctuated text'
+      };
+
+      mockClient.TransformText.mockResolvedValue(mockResponse);
+
+      const result = await service.punctuateText('test text', 'test-model');
+      expect(result).toEqual(mockResponse);
+      expect(mockClient.TransformText).toHaveBeenCalledWith({
+        text: 'test text',
+        model: 'test-model'
+      });
+    });
+
+    it('should handle gRPC errors properly', async () => {
+      const mockError = new Error('UNAVAILABLE: Server is currently unavailable') as ServiceError;
+      mockError.code = status.UNAVAILABLE;
+      mockError.details = 'Server is down for maintenance';
+      mockError.metadata = new Metadata();
+
+      mockClient.TransformText.mockRejectedValue(mockError);
+
+      await expect(service.punctuateText('test text', 'test-model')).rejects.toThrow(RivaError);
+      await expect(service.punctuateText('test text', 'test-model')).rejects.toThrow('UNAVAILABLE: Server is currently unavailable');
+    });
+  });
+
+  describe('analyzeEntities', () => {
+    it('should analyze entities with correct parameters', async () => {
+      const mockResponse: AnalyzeEntitiesResponse = {
+        entities: [{
+          text: 'test',
+          type: 'test',
+          score: 0.9,
+          start: 0,
+          end: 4
+        }]
+      };
+
+      mockClient.AnalyzeEntities.mockResolvedValue(mockResponse);
+
+      const result = await service.analyzeEntities('test text', 'test-model');
+      expect(result).toEqual(mockResponse);
+      expect(mockClient.AnalyzeEntities).toHaveBeenCalledWith({
+        text: 'test text'
+      });
+    });
+
+    it('should handle gRPC errors properly', async () => {
+      const mockError = new Error('UNAVAILABLE: Server is currently unavailable') as ServiceError;
+      mockError.code = status.UNAVAILABLE;
+      mockError.details = 'Server is down for maintenance';
+      mockError.metadata = new Metadata();
+
+      mockClient.AnalyzeEntities.mockRejectedValue(mockError);
+
+      await expect(service.analyzeEntities('test text', 'test-model')).rejects.toThrow(RivaError);
+      await expect(service.analyzeEntities('test text', 'test-model')).rejects.toThrow('UNAVAILABLE: Server is currently unavailable');
+    });
+  });
+
+  describe('analyzeIntent', () => {
+    it('should analyze intent with correct parameters', async () => {
+      const mockSlot: AnalyzeIntentResponse_Slot = {
+        text: 'tomorrow',
+        type: 'date',
+        score: 0.9
+      };
+      const mockResponse = {
+        intent: 'set_alarm',
+        confidence: 0.95,
+        slots: [mockSlot]
+      };
+
+      mockClient.AnalyzeIntent.mockResolvedValue(mockResponse);
+
+      const result = await service.analyzeIntent('test text', 'test-model');
+      expect(result).toEqual(mockResponse);
+      expect(mockClient.AnalyzeIntent).toHaveBeenCalledWith({
+        text: 'test text'
+      });
+    });
+
+    it('should handle gRPC errors properly', async () => {
+      const mockError = new Error('UNAVAILABLE: Server is currently unavailable') as ServiceError;
+      mockError.code = status.UNAVAILABLE;
+      mockError.details = 'Server is down for maintenance';
+      mockError.metadata = new Metadata();
+
+      mockClient.AnalyzeIntent.mockRejectedValue(mockError);
+
+      await expect(service.analyzeIntent('test text', 'test-model')).rejects.toThrow(RivaError);
+      await expect(service.analyzeIntent('test text', 'test-model')).rejects.toThrow('UNAVAILABLE: Server is currently unavailable');
+    });
+  });
+
+  describe('naturalQuery', () => {
+    it('should process natural query with correct parameters', async () => {
+      const mockResponse: NaturalQueryResponse = {
+        response: 'test answer',
+        confidence: 0.9
+      };
+
+      mockClient.NaturalQuery.mockResolvedValue(mockResponse);
+
+      const result = await service.naturalQuery('test question', 'test context');
+      expect(result).toEqual(mockResponse);
+      expect(mockClient.NaturalQuery).toHaveBeenCalledWith({
+        query: 'test question',
+        context: 'test context'
+      });
+    });
+
+    it('should handle gRPC errors properly', async () => {
+      const mockError = new Error('UNAVAILABLE: Server is currently unavailable') as ServiceError;
+      mockError.code = status.UNAVAILABLE;
+      mockError.details = 'Server is down for maintenance';
+      mockError.metadata = new Metadata();
+
+      mockClient.NaturalQuery.mockRejectedValue(mockError);
+
+      await expect(service.naturalQuery('test question', 'test context')).rejects.toThrow(RivaError);
+      await expect(service.naturalQuery('test question', 'test context')).rejects.toThrow('UNAVAILABLE: Server is currently unavailable');
+    });
+  });
+});
diff --git a/riva-ts-client/tests/unit/nmt.test.ts b/riva-ts-client/tests/unit/nmt.test.ts
new file mode 100644
index 00000000..a826cd1a
--- /dev/null
+++ b/riva-ts-client/tests/unit/nmt.test.ts
@@ -0,0 +1,335 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { NeuralMachineTranslationService } from '../../src/client/nmt';
+import { RivaError } from '../../src/client/errors';
+import { createGrpcMock } from './helpers/grpc';
+import type {
+  TranslateRequest,
+  TranslateResponse,
+  AvailableLanguageRequest,
+  AvailableLanguageResponse,
+  StreamingS2SRequest,
+  StreamingS2SResponse,
+  StreamingS2TRequest,
+  StreamingS2TResponse,
+  LanguagePair,
+  StreamingS2SConfig,
+  StreamingS2TConfig,
+  StreamingRecognitionConfig,
+  RecognitionConfig,
+  TranslationConfig,
+  SynthesizeSpeechConfig,
+  NMTServiceClient
+} from '../../src/client/nmt/types';
+import type { ServiceError } from '@grpc/grpc-js';
+import { Metadata, status } from '@grpc/grpc-js';
+import type { ClientReadableStream } from '@grpc/grpc-js';
+
+// Create mock client before setting up mocks
+const mockClient = createGrpcMock<NMTServiceClient>([
+  'translateText',
+  'listSupportedLanguagePairs',
+  'streamingTranslateSpeechToSpeech',
+  'streamingTranslateSpeechToText'
+]);
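+
+// Vitest hoists vi.mock() calls above other statements in this module, so the
+// factories below must only close over values (like mockClient above) that are
+// initialized when the module is first evaluated.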
+
+// Mock dependencies
+vi.mock('@grpc/grpc-js', async () => {
+  const actual = await vi.importActual('@grpc/grpc-js');
+  return {
+    ...actual,
+    credentials: {
+      createInsecure: vi.fn(),
+      createFromMetadataGenerator: vi.fn()
+    },
+    Metadata: vi.fn(),
+    Channel: vi.fn().mockImplementation(() => ({
+      getTarget: vi.fn(),
+      close: vi.fn(),
+      getConnectivityState: vi.fn(),
+      watchConnectivityState: vi.fn()
+    }))
+  };
+});
+
+vi.mock('../../src/client/utils/proto', () => ({
+  getProtoClient: () => ({
+    RivaSpeechTranslationClient: function() {
+      return mockClient;
+    }
+  })
+}));
+
+describe('NeuralMachineTranslationService', () => {
+  let service: NeuralMachineTranslationService;
+  const mockConfig = {
+    serverUrl: 'localhost:50051',
+    auth: {
+      ssl: false
+    }
+  };
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+    service = new NeuralMachineTranslationService(mockConfig);
+  });
+
+  describe('translate', () => {
+    const mockRequest: TranslateRequest = {
+      text: 'Hello world',
+      sourceLanguage: 'en-US',
+      targetLanguage: 'es-US'
+    };
+
+    it('should translate text successfully', async () => {
+      const mockResponse: TranslateResponse = {
+        translations: [{
+          text: 'Hola mundo',
+          score: 0.95
+        }],
+        text: 'Hola mundo',
+        score: 0.95
+      };
+
+      mockClient.translateText.mockResolvedValue(mockResponse);
+
+      const result = await service.translate(mockRequest);
+      expect(result).toEqual(mockResponse);
+      expect(mockClient.translateText).toHaveBeenCalledWith(mockRequest, expect.any(Metadata));
+    });
+
+    it('should handle gRPC errors', async () => {
+      const mockError = new Error('UNAVAILABLE: Server is currently unavailable') as ServiceError;
+      mockError.code = status.UNAVAILABLE;
+      mockError.details = 'Server is down for maintenance';
+      mockError.metadata = new Metadata();
+
+      mockClient.translateText.mockRejectedValue(mockError);
+
+      await expect(service.translate(mockRequest)).rejects.toThrow(RivaError);
+      await expect(service.translate(mockRequest)).rejects.toThrow('UNAVAILABLE: Server is currently unavailable');
+    });
+  });
+
+  describe('get_supported_language_pairs', () => {
+    const mockModel = 'nmt-model';
+    const mockRequest: AvailableLanguageRequest = { model: mockModel };
+
+    it('should get language pairs successfully', async () => {
+      const mockResponse: AvailableLanguageResponse = {
+        supportedLanguagePairs: [{
+          sourceLanguageCode: 'en-US',
+          targetLanguageCode: 'es-US'
+        }]
+      };
+
+      mockClient.listSupportedLanguagePairs.mockResolvedValue(mockResponse);
+
+      const result = await service.get_supported_language_pairs(mockModel);
+      expect(result).toEqual(mockResponse);
+      expect(mockClient.listSupportedLanguagePairs).toHaveBeenCalledWith(mockRequest, expect.any(Metadata));
+    });
+
+    it('should handle gRPC errors', async () => {
+      const mockError = new Error('UNAVAILABLE: Server is currently unavailable') as ServiceError;
+      mockError.code = status.UNAVAILABLE;
+      mockError.details = 'Server is down for maintenance';
+      mockError.metadata = new Metadata();
+
+      mockClient.listSupportedLanguagePairs.mockRejectedValue(mockError);
+
+      await expect(service.get_supported_language_pairs(mockModel)).rejects.toThrow(RivaError);
+      await expect(service.get_supported_language_pairs(mockModel)).rejects.toThrow('UNAVAILABLE: Server is currently unavailable');
+    });
+  });
+
+  describe('streaming_s2s_response_generator', () => {
+    const mockConfig: StreamingS2SConfig = {
+      asrConfig: {
+        config: {
+          languageCode: 'en-US',
+          audioEncoding: 1,
+          sampleRateHertz: 16000
+        },
+        interimResults: true
+      },
+      translationConfig: {
+        sourceLanguageCode: 'en-US',
+        targetLanguageCode: 'es-US'
+      },
+      ttsConfig: {
+        languageCode: 'es-US',
+        sampleRateHz: 16000
+      }
+    };
+  describe('streaming_s2s_response_generator', () => {
+    const mockConfig: StreamingS2SConfig = {
+      asrConfig: {
+        config: {
+          languageCode: 'en-US',
+          audioEncoding: 1,
+          sampleRateHertz: 16000
+        },
+        interimResults: true
+      },
+      translationConfig: {
+        sourceLanguageCode: 'en-US',
+        targetLanguageCode: 'es-US'
+      },
+      ttsConfig: {
+        languageCode: 'es-US',
+        sampleRateHz: 16000
+      }
+    };
+
+    function mockResponse(text: string, translation: string, isPartial: boolean): StreamingS2SResponse {
+      return {
+        result: {
+          transcript: text,
+          translation: translation,
+          audioContent: new Uint8Array(Buffer.from('mock audio')),
+          isPartial
+        }
+      };
+    }
+
+    it('should stream speech-to-speech translation', async () => {
+      const mockOn = vi.fn((event: string, callback: (...args: any[]) => void) => {
+        if (event === 'data') {
+          setTimeout(() => {
+            callback(mockResponse('Hello', 'Hola', true));
+            callback(mockResponse('Hello world', 'Hola mundo', false));
+            const endCallback = mockOn.mock.calls.find(([evt]) => evt === 'end')?.[1];
+            if (endCallback) endCallback();
+          }, 0);
+        }
+        return mockStream;
+      });
+
+      const mockStream = {
+        on: mockOn,
+        removeListener: vi.fn(),
+        [Symbol.asyncIterator]: async function* () {
+          yield mockResponse('Hello', 'Hola', true);
+          yield mockResponse('Hello world', 'Hola mundo', false);
+        }
+      } as any as ClientReadableStream<StreamingS2SResponse>;
+
+      mockClient.streamingTranslateSpeechToSpeech.mockReturnValue(mockStream);
+
+      const audioChunks = [new Uint8Array(Buffer.from('chunk1')), new Uint8Array(Buffer.from('chunk2'))];
+      const responses: StreamingS2SResponse[] = [];
+
+      for await (const response of service.streaming_s2s_response_generator(audioChunks, mockConfig)) {
+        responses.push(response);
+      }
+
+      expect(responses).toHaveLength(2);
+      expect(responses[0].result.transcript).toBe('Hello');
+      expect(responses[0].result.translation).toBe('Hola');
+      expect(responses[0].result.isPartial).toBe(true);
+      expect(responses[1].result.transcript).toBe('Hello world');
+      expect(responses[1].result.translation).toBe('Hola mundo');
+      expect(responses[1].result.isPartial).toBe(false);
+    });
+
+    it('should handle stream errors', async () => {
+      const mockStream = {
+        on: vi.fn((event: string, callback: (...args: any[]) => void) => {
+          if (event === 'error') {
+            setTimeout(() => {
+              callback(new Error('Stream error'));
+            }, 0);
+          }
+          return mockStream;
+        }),
+        removeListener: vi.fn(),
+        [Symbol.asyncIterator]: async function* () {
+          throw new Error('Stream error');
+        }
+      } as any as ClientReadableStream<StreamingS2SResponse>;
+
+      mockClient.streamingTranslateSpeechToSpeech.mockReturnValue(mockStream);
+
+      const audioChunks = [new Uint8Array(Buffer.from('chunk1'))];
+      await expect(async () => {
+        for await (const _ of service.streaming_s2s_response_generator(audioChunks, mockConfig)) {
+          // consume stream
+        }
+      }).rejects.toThrow('Stream error');
+    });
+  });
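+  // The hand-built mock stream above (an `on`-based emitter plus an async
+  // iterator, double-cast to ClientReadableStream) recurs in the S2T and TTS
+  // suites as well. A shared helper along these lines could replace it: a
+  // sketch only, not something this diff ships.
+  //
+  //   function makeMockStream<T>(items: T[]): ClientReadableStream<T> {
+  //     let endCb: ((...args: any[]) => void) | undefined;
+  //     const stream = {
+  //       on(event: string, cb: (...args: any[]) => void) {
+  //         if (event === 'end') endCb = cb;
+  //         if (event === 'data') setTimeout(() => { items.forEach(cb); endCb?.(); }, 0);
+  //         return stream;
+  //       },
+  //       removeListener: vi.fn(),
+  //       [Symbol.asyncIterator]: async function* () { yield* items; }
+  //     };
+  //     return stream as any as ClientReadableStream<T>;
+  //   }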
+  describe('streaming_s2t_response_generator', () => {
+    const mockConfig: StreamingS2TConfig = {
+      asrConfig: {
+        config: {
+          languageCode: 'en-US',
+          audioEncoding: 1,
+          sampleRateHertz: 16000
+        },
+        interimResults: true
+      },
+      translationConfig: {
+        sourceLanguageCode: 'en-US',
+        targetLanguageCode: 'es-US'
+      }
+    };
+
+    function mockResponse(text: string, translation: string, isPartial: boolean): StreamingS2TResponse {
+      return {
+        result: {
+          transcript: text,
+          translation: translation,
+          isPartial
+        }
+      };
+    }
+
+    it('should stream speech-to-text translation', async () => {
+      const mockOn = vi.fn((event: string, callback: (...args: any[]) => void) => {
+        if (event === 'data') {
+          setTimeout(() => {
+            callback(mockResponse('Hello', 'Hola', true));
+            callback(mockResponse('Hello world', 'Hola mundo', false));
+            const endCallback = mockOn.mock.calls.find(([evt]) => evt === 'end')?.[1];
+            if (endCallback) endCallback();
+          }, 0);
+        }
+        return mockStream;
+      });
+
+      const mockStream = {
+        on: mockOn,
+        removeListener: vi.fn(),
+        [Symbol.asyncIterator]: async function* () {
+          yield mockResponse('Hello', 'Hola', true);
+          yield mockResponse('Hello world', 'Hola mundo', false);
+        }
+      } as any as ClientReadableStream<StreamingS2TResponse>;
+
+      mockClient.streamingTranslateSpeechToText.mockReturnValue(mockStream);
+
+      const audioChunks = [new Uint8Array(Buffer.from('chunk1')), new Uint8Array(Buffer.from('chunk2'))];
+      const responses: StreamingS2TResponse[] = [];
+
+      for await (const response of service.streaming_s2t_response_generator(audioChunks, mockConfig)) {
+        responses.push(response);
+      }
+
+      expect(responses).toHaveLength(2);
+      expect(responses[0].result.transcript).toBe('Hello');
+      expect(responses[0].result.translation).toBe('Hola');
+      expect(responses[0].result.isPartial).toBe(true);
+      expect(responses[1].result.transcript).toBe('Hello world');
+      expect(responses[1].result.translation).toBe('Hola mundo');
+      expect(responses[1].result.isPartial).toBe(false);
+    });
+
+    it('should handle stream errors', async () => {
+      const mockStream = {
+        on: vi.fn((event: string, callback: (...args: any[]) => void) => {
+          if (event === 'error') {
+            setTimeout(() => {
+              callback(new Error('Stream error'));
+            }, 0);
+          }
+          return mockStream;
+        }),
+        removeListener: vi.fn(),
+        [Symbol.asyncIterator]: async function* () {
+          throw new Error('Stream error');
+        }
+      } as any as ClientReadableStream<StreamingS2TResponse>;
+
+      mockClient.streamingTranslateSpeechToText.mockReturnValue(mockStream);
+
+      const audioChunks = [new Uint8Array(Buffer.from('chunk1'))];
+      await expect(async () => {
+        for await (const _ of service.streaming_s2t_response_generator(audioChunks, mockConfig)) {
+          // consume stream
+        }
+      }).rejects.toThrow('Stream error');
+    });
+  });
+});
diff --git a/riva-ts-client/tests/unit/tts.test.ts b/riva-ts-client/tests/unit/tts.test.ts
new file mode 100644
index 00000000..f6be75ff
--- /dev/null
+++ b/riva-ts-client/tests/unit/tts.test.ts
@@ -0,0 +1,289 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { SpeechSynthesisService } from '../../src/client/tts';
+import { AudioEncoding } from '../../src/client/asr/types';
+import { RivaError } from '../../src/client/errors';
+import { createGrpcMock } from './helpers/grpc';
+import * as fs from 'fs';
+import { WaveFile } from 'wavefile';
+import type {
+  SynthesizeSpeechRequest,
+  SynthesizeSpeechResponse,
+  RivaSynthesisConfigResponse
+} from '../../src/client/tts/types';
+import type { ServiceError, ClientReadableStream } from '@grpc/grpc-js';
+import { Metadata, status } from '@grpc/grpc-js';
+
+// Create mock client before setting up mocks
+const mockClient = createGrpcMock(['GetRivaSynthesisConfig', 'Synthesize', 'SynthesizeOnline']);
+
+// Mock dependencies
+vi.mock('fs', () => ({
+  readFileSync: vi.fn()
+}));
+
+vi.mock('wavefile', () => ({
+  WaveFile: vi.fn().mockImplementation(() => ({
+    getSamples: vi.fn(),
+    fmt: { sampleRate: 44100 }
+  }))
+}));
+
+vi.mock('@grpc/grpc-js', async () => {
+  const actual = await vi.importActual('@grpc/grpc-js');
+  return {
+    ...actual,
+    credentials: {
+      createInsecure: vi.fn(),
+      createFromMetadataGenerator: vi.fn()
+    },
+    Metadata: vi.fn(),
+    Channel: vi.fn().mockImplementation(() => ({
+      getTarget: vi.fn(),
+      close: vi.fn(),
+      getConnectivityState: vi.fn(),
+      watchConnectivityState: vi.fn()
+    }))
+  };
+});
+
+vi.mock('../../src/client/utils/proto', () => ({
+  getProtoClient: () => ({
+    RivaSpeechSynthesisStub: function() {
+      return mockClient;
+    }
+  })
+}));
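+// fs and wavefile are both mocked above, so the zero-shot prompt tests below
+// never touch the filesystem; the wavefile stub pins fmt.sampleRate to 44100,
+// and individual tests re-mock getSamples to exercise the failure paths.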
+describe('SpeechSynthesisService', () => {
+  let service: SpeechSynthesisService;
+  const mockConfig = {
+    serverUrl: 'localhost:50051',
+    auth: {
+      ssl: false,
+      credentials: undefined
+    }
+  };
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+    service = new SpeechSynthesisService(mockConfig);
+  });
+
+  describe('getRivaSynthesisConfig', () => {
+    it('should get synthesis config successfully', async () => {
+      const mockResponse: RivaSynthesisConfigResponse = {
+        modelConfig: [{
+          parameters: {
+            languageCode: 'en-US',
+            voiceName: 'test-voice',
+            subvoices: 'voice1,voice2'
+          }
+        }]
+      };
+
+      mockClient.GetRivaSynthesisConfig.mockResolvedValue(mockResponse);
+
+      const result = await service.getRivaSynthesisConfig();
+      expect(result).toEqual(mockResponse);
+      expect(mockClient.GetRivaSynthesisConfig).toHaveBeenCalledWith({}, expect.any(Metadata));
+    });
+
+    it('should handle gRPC errors properly', async () => {
+      const mockGrpcError = new Error('UNAVAILABLE: Server is currently unavailable') as ServiceError;
+      mockGrpcError.code = status.UNAVAILABLE;
+      mockGrpcError.details = 'Server is down for maintenance';
+      mockGrpcError.metadata = new Metadata();
+
+      mockClient.GetRivaSynthesisConfig.mockRejectedValue(mockGrpcError);
+
+      await expect(service.getRivaSynthesisConfig()).rejects.toThrow(RivaError);
+      await expect(service.getRivaSynthesisConfig()).rejects.toThrow('UNAVAILABLE: Server is currently unavailable');
+    });
+  });
+
+  describe('synthesize', () => {
+    const defaultRequest: SynthesizeSpeechRequest = {
+      text: 'test text',
+      languageCode: 'en-US',
+      sampleRateHz: 44100,
+      encoding: AudioEncoding.LINEAR_PCM
+    };
+
+    const defaultResponse: SynthesizeSpeechResponse = {
+      audio: new Uint8Array(Buffer.from('test audio')),
+      audioConfig: {
+        encoding: AudioEncoding.LINEAR_PCM,
+        sampleRateHz: 44100
+      }
+    };
+
+    it('should synthesize with default parameters', async () => {
+      mockClient.Synthesize.mockResolvedValue(defaultResponse);
+
+      const result = await service.synthesize('test text');
+      expect(result).toEqual(defaultResponse);
+      expect(mockClient.Synthesize).toHaveBeenCalledWith(defaultRequest, expect.any(Metadata));
+    });
+
+    it('should synthesize with custom voice', async () => {
+      const voiceName = 'English-US-Female-1';
+      mockClient.Synthesize.mockResolvedValue(defaultResponse);
+
+      await service.synthesize('test text', voiceName);
+      expect(mockClient.Synthesize).toHaveBeenCalledWith({
+        ...defaultRequest,
+        voiceName
+      }, expect.any(Metadata));
+    });
+
+    it('should handle zero-shot synthesis with audio prompt', async () => {
+      const audioPromptFile = 'test.wav';
+      const mockSamples = new Float64Array([1, 2, 3]);
+      const mockWaveFile = {
+        getSamples: vi.fn().mockReturnValue(mockSamples),
+        fmt: { sampleRate: 44100 }
+      };
+      vi.mocked(WaveFile).mockImplementation(() => mockWaveFile as any);
+      vi.mocked(fs.readFileSync).mockReturnValue(Buffer.from('test'));
+
+      mockClient.Synthesize.mockResolvedValue(defaultResponse);
+
+      await service.synthesize('test text', undefined, 'en-US', AudioEncoding.LINEAR_PCM, 44100, audioPromptFile);
+
+      expect(mockClient.Synthesize).toHaveBeenCalledWith({
+        ...defaultRequest,
+        zeroShotData: {
+          audioPrompt: expect.any(Buffer),
+          encoding: AudioEncoding.LINEAR_PCM,
+          sampleRateHz: 44100,
+          quality: 20
+        }
+      }, expect.any(Metadata));
+    });
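+    // As exercised above and below, synthesize takes positional arguments in
+    // the order (text, voiceName, languageCode, encoding, sampleRateHz,
+    // audioPromptFile, ...). The later positions are only inferred from the
+    // custom dictionary test: audio prompt encoding, quality, an unnamed
+    // boolean flag, then customDictionary.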
+    it('should handle invalid WAV file errors', async () => {
+      const audioPromptFile = 'test.wav';
+      vi.mocked(WaveFile).mockImplementation(() => ({
+        getSamples: vi.fn().mockReturnValue(null),
+        fmt: { sampleRate: 44100 }
+      }) as any);
+
+      await expect(
+        service.synthesize('test text', undefined, 'en-US', AudioEncoding.LINEAR_PCM, 44100, audioPromptFile)
+      ).rejects.toThrow('Invalid WAV file: no samples found');
+    });
+
+    it('should handle WAV file without sample rate', async () => {
+      const audioPromptFile = 'test.wav';
+      vi.mocked(WaveFile).mockImplementation(() => ({
+        getSamples: vi.fn().mockReturnValue(new Float64Array([1, 2, 3])),
+        fmt: {}
+      }) as any);
+
+      await expect(
+        service.synthesize('test text', undefined, 'en-US', AudioEncoding.LINEAR_PCM, 44100, audioPromptFile)
+      ).rejects.toThrow('Invalid WAV file: no sample rate found');
+    });
+
+    it('should add custom dictionary to request', async () => {
+      const customDictionary = {
+        'word1': 'phoneme1',
+        'word2': 'phoneme2'
+      };
+      mockClient.Synthesize.mockResolvedValue(defaultResponse);
+
+      await service.synthesize('test text', undefined, 'en-US', AudioEncoding.LINEAR_PCM, 44100, undefined, AudioEncoding.LINEAR_PCM, 20, false, customDictionary);
+
+      expect(mockClient.Synthesize).toHaveBeenCalledWith({
+        ...defaultRequest,
+        customDictionary: 'word1 phoneme1,word2 phoneme2'
+      }, expect.any(Metadata));
+    });
+  });
+
+  describe('synthesizeOnline', () => {
+    const defaultRequest: SynthesizeSpeechRequest = {
+      text: 'test text',
+      languageCode: 'en-US',
+      sampleRateHz: 44100,
+      encoding: AudioEncoding.LINEAR_PCM
+    };
+
+    function mockChunkResponse(data: string): SynthesizeSpeechResponse {
+      return {
+        audio: new Uint8Array(Buffer.from(data)),
+        audioConfig: {
+          encoding: AudioEncoding.LINEAR_PCM,
+          sampleRateHz: 44100
+        }
+      };
+    }
+
+    it('should stream synthesis results', async () => {
+      // Create a mock stream that implements the ClientReadableStream interface
+      const mockOn = vi.fn((event: string, callback: (...args: any[]) => void) => {
+        if (event === 'data') {
+          setTimeout(() => {
+            callback(mockChunkResponse('chunk1'));
+            callback(mockChunkResponse('chunk2'));
+            const endCallback = mockOn.mock.calls.find(([evt]) => evt === 'end')?.[1];
+            if (endCallback) endCallback();
+          }, 0);
+        }
+        return mockStream;
+      });
+
+      const mockStream = {
+        on: mockOn,
+        removeListener: vi.fn(),
+        [Symbol.asyncIterator]: async function* () {
+          yield mockChunkResponse('chunk1');
+          yield mockChunkResponse('chunk2');
+        }
+      } as any as ClientReadableStream<SynthesizeSpeechResponse>;
+
+      mockClient.SynthesizeOnline.mockReturnValue(mockStream);
+
+      const chunks: SynthesizeSpeechResponse[] = [];
+      for await (const chunk of service.synthesizeOnline('test text')) {
+        chunks.push(chunk);
+      }
+
+      expect(chunks).toHaveLength(2);
+      expect(chunks[0]).toEqual(mockChunkResponse('chunk1'));
+      expect(chunks[1]).toEqual(mockChunkResponse('chunk2'));
+      expect(mockClient.SynthesizeOnline).toHaveBeenCalledWith(defaultRequest, expect.any(Metadata));
+    });
+
+    it('should handle stream errors', async () => {
+      // Create a mock stream that implements the ClientReadableStream interface
+      const mockStream = {
+        on: vi.fn((event: string, callback: (...args: any[]) => void) => {
+          if (event === 'error') {
+            setTimeout(() => {
+              callback(new Error('Stream error'));
+            }, 0);
+          }
+          return mockStream;
+        }),
+        removeListener: vi.fn(),
+        [Symbol.asyncIterator]: async function* () {
+          throw new Error('Stream error');
+        }
+      } as any as ClientReadableStream<SynthesizeSpeechResponse>;
+
+      mockClient.SynthesizeOnline.mockReturnValue(mockStream);
+
+      await expect(async () => {
+        for await (const _ of service.synthesizeOnline('test text')) {
+          // consume stream
+        }
+      }).rejects.toThrow('Stream error');
+    });
+  });
+});
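+// Note on the custom dictionary test above: per its assertion, the
+// { word: phoneme } map is flattened to "word1 phoneme1,word2 phoneme2"
+// (a space within each pair, commas between pairs) before being set on
+// SynthesizeSpeechRequest.customDictionary.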
diff --git a/riva-ts-client/tsconfig.json b/riva-ts-client/tsconfig.json
new file mode 100644
index 00000000..ab44fad4
--- /dev/null
+++ b/riva-ts-client/tsconfig.json
@@ -0,0 +1,25 @@
+{
+  "compilerOptions": {
+    "target": "es2020",
+    "module": "commonjs",
+    "declaration": true,
+    "outDir": "./dist",
+    "rootDir": "./src",
+    "strict": true,
+    "esModuleInterop": true,
+    "skipLibCheck": true,
+    "forceConsistentCasingInFileNames": true,
+    "resolveJsonModule": true,
+    "sourceMap": true,
+    "baseUrl": ".",
+    "paths": {
+      "@/*": ["src/*"]
+    },
+    "typeRoots": [
+      "./node_modules/@types",
+      "./src/types"
+    ]
+  },
+  "include": ["src/**/*"],
+  "exclude": ["node_modules", "dist", "**/*.test.ts"]
+}
\ No newline at end of file
diff --git a/riva-ts-client/vitest.config.ts b/riva-ts-client/vitest.config.ts
new file mode 100644
index 00000000..6bd00184
--- /dev/null
+++ b/riva-ts-client/vitest.config.ts
@@ -0,0 +1,16 @@
+/// <reference types="vitest" />
+import { defineConfig } from 'vitest/config';
+
+export default defineConfig({
+  test: {
+    globals: true,
+    environment: 'node',
+    include: ['tests/unit/**/*.test.ts'],
+    coverage: {
+      provider: 'v8',
+      reporter: ['text', 'json', 'html'],
+      include: ['src/**/*.ts'],
+      exclude: ['**/node_modules/**', '**/dist/**', '**/types/**']
+    }
+  }
+});
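+// Note: the unit tests all import describe/it/expect from 'vitest' explicitly,
+// so `globals: true` is redundant for the current suites; it is kept in case
+// future tests rely on injected globals.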