Skip to content

Commit 2ab2e0e

Browse files
support file formats other than LINEAR_PCM (#47)
* Support file formats other than LINEAR_PCM Disable setting audio encoding parameters explicitly, let Riva do the decode. * add ffmpeg-python requirement * additional fixes and remove ffmpeg dependency
1 parent 9ab7d05 commit 2ab2e0e

File tree

4 files changed

+32
-25
lines changed

4 files changed

+32
-25
lines changed

riva/client/asr.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,29 @@
1212

1313
from grpc._channel import _MultiThreadedRendezvous
1414

15+
import riva.client
1516
import riva.client.proto.riva_asr_pb2 as rasr
1617
import riva.client.proto.riva_asr_pb2_grpc as rasr_srv
1718
from riva.client.auth import Auth
1819

1920

2021
def get_wav_file_parameters(input_file: Union[str, os.PathLike]) -> Dict[str, Union[int, float]]:
21-
input_file = Path(input_file).expanduser()
22-
with wave.open(str(input_file), 'rb') as wf:
23-
nframes = wf.getnframes()
24-
rate = wf.getframerate()
25-
parameters = {
26-
'nframes': nframes,
27-
'framerate': rate,
28-
'duration': nframes / rate,
29-
'nchannels': wf.getnchannels(),
30-
'sampwidth': wf.getsampwidth(),
31-
}
22+
try:
23+
input_file = Path(input_file).expanduser()
24+
with wave.open(str(input_file), 'rb') as wf:
25+
nframes = wf.getnframes()
26+
rate = wf.getframerate()
27+
parameters = {
28+
'nframes': nframes,
29+
'framerate': rate,
30+
'duration': nframes / rate,
31+
'nchannels': wf.getnchannels(),
32+
'sampwidth': wf.getsampwidth(),
33+
'data_offset': wf.getfp().size_read + wf.getfp().offset
34+
}
35+
except:
36+
# Not a WAV file
37+
return None
3238
return parameters
3339

3440

@@ -47,7 +53,11 @@ def __init__(
4753
self.chunk_n_frames = chunk_n_frames
4854
self.delay_callback = delay_callback
4955
self.file_parameters = get_wav_file_parameters(self.input_file)
50-
self.file_object: Optional[wave.Wave_read] = wave.open(str(self.input_file), 'rb')
56+
self.file_object: Optional[typing.BinaryIO] = open(str(self.input_file), 'rb')
57+
if self.delay_callback and self.file_parameters is None:
58+
warnings.warn(f"delay_callback not supported for encoding other than LINEAR_PCM")
59+
self.delay_callback = None
60+
self.first_buffer = True
5161

5262
def close(self) -> None:
5363
self.file_object.close()
@@ -64,15 +74,19 @@ def __iter__(self):
6474
return self
6575

6676
def __next__(self) -> bytes:
67-
data = self.file_object.readframes(self.chunk_n_frames)
77+
if self.file_parameters:
78+
data = self.file_object.read(self.chunk_n_frames * self.file_parameters['sampwidth'] * self.file_parameters['nchannels'])
79+
else:
80+
data = self.file_object.read(self.chunk_n_frames)
6881
if not data:
6982
self.close()
7083
raise StopIteration
7184
if self.delay_callback is not None:
85+
offset = self.file_parameters['data_offset'] if self.first_buffer else 0
7286
self.delay_callback(
73-
data,
74-
len(data) / self.file_parameters['sampwidth'] / self.file_parameters['framerate']
87+
data[offset:], (len(data) - offset) / self.file_parameters['sampwidth'] / self.file_parameters['framerate']
7588
)
89+
self.first_buffer = False
7690
return data
7791

7892

@@ -95,8 +109,9 @@ def add_audio_file_specs_to_config(
95109
) -> None:
96110
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
97111
wav_parameters = get_wav_file_parameters(audio_file)
98-
inner_config.sample_rate_hertz = wav_parameters['framerate']
99-
inner_config.audio_channel_count = wav_parameters['nchannels']
112+
if wav_parameters is not None:
113+
inner_config.sample_rate_hertz = wav_parameters['framerate']
114+
inner_config.audio_channel_count = wav_parameters['nchannels']
100115

101116

102117
def add_speaker_diarization_to_config(

scripts/asr/riva_streaming_asr_client.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def streaming_transcription_worker(
5454
asr_service = riva.client.ASRService(auth)
5555
config = riva.client.StreamingRecognitionConfig(
5656
config=riva.client.RecognitionConfig(
57-
encoding=riva.client.AudioEncoding.LINEAR_PCM,
5857
language_code=args.language_code,
5958
max_alternatives=args.max_alternatives,
6059
profanity_filter=args.profanity_filter,
@@ -64,7 +63,6 @@ def streaming_transcription_worker(
6463
),
6564
interim_results=True,
6665
)
67-
riva.client.add_audio_file_specs_to_config(config, args.input_file)
6866
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
6967
for _ in range(args.num_iterations):
7068
with riva.client.AudioChunkFileIterator(
@@ -92,8 +90,6 @@ def main() -> None:
9290
print("Number of clients:", args.num_clients)
9391
print("Number of iteration:", args.num_iterations)
9492
print("Input file:", args.input_file)
95-
wav_parameters = get_wav_file_parameters(args.input_file)
96-
print(f"File duration: {wav_parameters['duration']:.2f}s")
9793
threads = []
9894
exception_queue = queue.Queue()
9995
for i in range(args.num_clients):

scripts/asr/transcribe_file.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def main() -> None:
7070
asr_service = riva.client.ASRService(auth)
7171
config = riva.client.StreamingRecognitionConfig(
7272
config=riva.client.RecognitionConfig(
73-
encoding=riva.client.AudioEncoding.LINEAR_PCM,
7473
language_code=args.language_code,
7574
max_alternatives=1,
7675
profanity_filter=args.profanity_filter,
@@ -79,7 +78,6 @@ def main() -> None:
7978
),
8079
interim_results=True,
8180
)
82-
riva.client.add_audio_file_specs_to_config(config, args.input_file)
8381
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
8482
sound_callback = None
8583
try:

scripts/asr/transcribe_file_offline.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,13 @@ def main() -> None:
2929
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
3030
asr_service = riva.client.ASRService(auth)
3131
config = riva.client.RecognitionConfig(
32-
encoding=riva.client.AudioEncoding.LINEAR_PCM,
3332
language_code=args.language_code,
3433
max_alternatives=args.max_alternatives,
3534
profanity_filter=args.profanity_filter,
3635
enable_automatic_punctuation=args.automatic_punctuation,
3736
verbatim_transcripts=not args.no_verbatim_transcripts,
3837
enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
3938
)
40-
riva.client.add_audio_file_specs_to_config(config, args.input_file)
4139
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
4240
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)
4341

0 commit comments

Comments
 (0)