12
12
13
13
from grpc ._channel import _MultiThreadedRendezvous
14
14
15
+ import riva .client
15
16
import riva .client .proto .riva_asr_pb2 as rasr
16
17
import riva .client .proto .riva_asr_pb2_grpc as rasr_srv
17
18
from riva .client .auth import Auth
18
19
19
20
20
21
def get_wav_file_parameters (input_file : Union [str , os .PathLike ]) -> Dict [str , Union [int , float ]]:
21
- input_file = Path (input_file ).expanduser ()
22
- with wave .open (str (input_file ), 'rb' ) as wf :
23
- nframes = wf .getnframes ()
24
- rate = wf .getframerate ()
25
- parameters = {
26
- 'nframes' : nframes ,
27
- 'framerate' : rate ,
28
- 'duration' : nframes / rate ,
29
- 'nchannels' : wf .getnchannels (),
30
- 'sampwidth' : wf .getsampwidth (),
31
- }
22
+ try :
23
+ input_file = Path (input_file ).expanduser ()
24
+ with wave .open (str (input_file ), 'rb' ) as wf :
25
+ nframes = wf .getnframes ()
26
+ rate = wf .getframerate ()
27
+ parameters = {
28
+ 'nframes' : nframes ,
29
+ 'framerate' : rate ,
30
+ 'duration' : nframes / rate ,
31
+ 'nchannels' : wf .getnchannels (),
32
+ 'sampwidth' : wf .getsampwidth (),
33
+ 'data_offset' : wf .getfp ().size_read + wf .getfp ().offset
34
+ }
35
+ except :
36
+ # Not a WAV file
37
+ return None
32
38
return parameters
33
39
34
40
@@ -47,7 +53,11 @@ def __init__(
47
53
self .chunk_n_frames = chunk_n_frames
48
54
self .delay_callback = delay_callback
49
55
self .file_parameters = get_wav_file_parameters (self .input_file )
50
- self .file_object : Optional [wave .Wave_read ] = wave .open (str (self .input_file ), 'rb' )
56
+ self .file_object : Optional [typing .BinaryIO ] = open (str (self .input_file ), 'rb' )
57
+ if self .delay_callback and self .file_parameters is None :
58
+ warnings .warn (f"delay_callback not supported for encoding other than LINEAR_PCM" )
59
+ self .delay_callback = None
60
+ self .first_buffer = True
51
61
52
62
def close (self ) -> None :
53
63
self .file_object .close ()
@@ -64,15 +74,19 @@ def __iter__(self):
64
74
return self
65
75
66
76
def __next__ (self ) -> bytes :
67
- data = self .file_object .readframes (self .chunk_n_frames )
77
+ if self .file_parameters :
78
+ data = self .file_object .read (self .chunk_n_frames * self .file_parameters ['sampwidth' ] * self .file_parameters ['nchannels' ])
79
+ else :
80
+ data = self .file_object .read (self .chunk_n_frames )
68
81
if not data :
69
82
self .close ()
70
83
raise StopIteration
71
84
if self .delay_callback is not None :
85
+ offset = self .file_parameters ['data_offset' ] if self .first_buffer else 0
72
86
self .delay_callback (
73
- data ,
74
- len (data ) / self .file_parameters ['sampwidth' ] / self .file_parameters ['framerate' ]
87
+ data [offset :], (len (data ) - offset ) / self .file_parameters ['sampwidth' ] / self .file_parameters ['framerate' ]
75
88
)
89
+ self .first_buffer = False
76
90
return data
77
91
78
92
@@ -95,8 +109,9 @@ def add_audio_file_specs_to_config(
95
109
) -> None :
96
110
inner_config : rasr .RecognitionConfig = config if isinstance (config , rasr .RecognitionConfig ) else config .config
97
111
wav_parameters = get_wav_file_parameters (audio_file )
98
- inner_config .sample_rate_hertz = wav_parameters ['framerate' ]
99
- inner_config .audio_channel_count = wav_parameters ['nchannels' ]
112
+ if wav_parameters is not None :
113
+ inner_config .sample_rate_hertz = wav_parameters ['framerate' ]
114
+ inner_config .audio_channel_count = wav_parameters ['nchannels' ]
100
115
101
116
102
117
def add_speaker_diarization_to_config (
0 commit comments