11import torch
2+
23from ._ext import ctc_decode
34
45
@@ -17,13 +18,24 @@ class CTCBeamDecoder(object):
1718 cutoff_prob (float): Cutoff probability in pruning. 1.0 means no pruning.
1819 beam_width (int): This controls how broad the beam search is. Higher values are more likely to find top beams,
1920 but they also will make your beam search exponentially slower.
20- num_processes (int): Parallelize the batch using num_processes workers.
21+ num_processes (int): Parallelize the batch using num_processes workers.
2122 blank_id (int): Index of the CTC blank token (probably 0) used when training your model.
2223 log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1.
2324 """
2425
25- def __init__ (self , labels , model_path = None , alpha = 0 , beta = 0 , cutoff_top_n = 40 , cutoff_prob = 1.0 , beam_width = 100 ,
26- num_processes = 4 , blank_id = 0 , log_probs_input = False ):
26+ def __init__ (
27+ self ,
28+ labels ,
29+ model_path = None ,
30+ alpha = 0 ,
31+ beta = 0 ,
32+ cutoff_top_n = 40 ,
33+ cutoff_prob = 1.0 ,
34+ beam_width = 100 ,
35+ num_processes = 4 ,
36+ blank_id = 0 ,
37+ log_probs_input = False ,
38+ ):
2739 self .cutoff_top_n = cutoff_top_n
2840 self ._beam_width = beam_width
2941 self ._scorer = None
@@ -33,8 +45,9 @@ def __init__(self, labels, model_path=None, alpha=0, beta=0, cutoff_top_n=40, cu
3345 self ._blank_id = blank_id
3446 self ._log_probs = 1 if log_probs_input else 0
3547 if model_path :
36- self ._scorer = ctc_decode .paddle_get_scorer (alpha , beta , model_path .encode (), self ._labels ,
37- self ._num_labels )
48+ self ._scorer = ctc_decode .paddle_get_scorer (
49+ alpha , beta , model_path .encode (), self ._labels , self ._num_labels
50+ )
3851 self ._cutoff_prob = cutoff_prob
3952
4053 def decode (self , probs , seq_lens = None ):
@@ -72,14 +85,40 @@ def decode(self, probs, seq_lens=None):
7285 scores = torch .FloatTensor (batch_size , self ._beam_width ).cpu ().float ()
7386 out_seq_len = torch .zeros (batch_size , self ._beam_width ).cpu ().int ()
7487 if self ._scorer :
75- ctc_decode .paddle_beam_decode_lm (probs , seq_lens , self ._labels , self ._num_labels , self ._beam_width ,
76- self ._num_processes , self ._cutoff_prob , self .cutoff_top_n , self ._blank_id ,
77- self ._log_probs , self ._scorer , output , timesteps , scores , out_seq_len )
88+ ctc_decode .paddle_beam_decode_lm (
89+ probs ,
90+ seq_lens ,
91+ self ._labels ,
92+ self ._num_labels ,
93+ self ._beam_width ,
94+ self ._num_processes ,
95+ self ._cutoff_prob ,
96+ self .cutoff_top_n ,
97+ self ._blank_id ,
98+ self ._log_probs ,
99+ self ._scorer ,
100+ output ,
101+ timesteps ,
102+ scores ,
103+ out_seq_len ,
104+ )
78105 else :
79- ctc_decode .paddle_beam_decode (probs , seq_lens , self ._labels , self ._num_labels , self ._beam_width ,
80- self ._num_processes ,
81- self ._cutoff_prob , self .cutoff_top_n , self ._blank_id , self ._log_probs ,
82- output , timesteps , scores , out_seq_len )
106+ ctc_decode .paddle_beam_decode (
107+ probs ,
108+ seq_lens ,
109+ self ._labels ,
110+ self ._num_labels ,
111+ self ._beam_width ,
112+ self ._num_processes ,
113+ self ._cutoff_prob ,
114+ self .cutoff_top_n ,
115+ self ._blank_id ,
116+ self ._log_probs ,
117+ output ,
118+ timesteps ,
119+ scores ,
120+ out_seq_len ,
121+ )
83122
84123 return output , scores , timesteps , out_seq_len
85124
@@ -99,3 +138,135 @@ def reset_params(self, alpha, beta):
def __del__(self):
    """Release the native KenLM scorer, if one was created."""
    # getattr guard: __del__ can run even when __init__ raised before
    # self._scorer was assigned (e.g. bad arguments), and must never
    # raise during interpreter teardown.
    if getattr(self, "_scorer", None) is not None:
        ctc_decode.paddle_release_scorer(self._scorer)
141+
142+
class OnlineCTCBeamDecoder(object):
    """
    PyTorch wrapper for the DeepSpeech PaddlePaddle beam search decoder with an
    interface for online (chunk-by-chunk) decoding.

    Args:
        labels (list): The tokens/vocab used to train your model.
            They should be in the same order as they are in your model's outputs.
        model_path (basestring): The path to your external KenLM language model (LM).
        alpha (float): Weighting associated with the LM's probabilities.
            A weight of 0 means the LM has no effect.
        beta (float): Weight associated with the number of words within our beam.
        cutoff_top_n (int): Cutoff number in pruning. Only the top cutoff_top_n characters
            with the highest probability in the vocab will be used in beam search.
        cutoff_prob (float): Cutoff probability in pruning. 1.0 means no pruning.
        beam_width (int): This controls how broad the beam search is. Higher values are more
            likely to find top beams, but they also will make your beam search
            exponentially slower.
        num_processes (int): Parallelize the batch using num_processes workers.
        blank_id (int): Index of the CTC blank token (probably 0) used when training your model.
        log_probs_input (bool): False if your model has passed through a softmax and output
            probabilities sum to 1.
    """

    def __init__(
        self,
        labels,
        model_path=None,
        alpha=0,
        beta=0,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=100,
        num_processes=4,
        blank_id=0,
        log_probs_input=False,
    ):
        self._cutoff_top_n = cutoff_top_n
        self._beam_width = beam_width
        self._scorer = None
        self._num_processes = num_processes
        self._labels = list(labels)  # Ensure labels are a list
        self._num_labels = len(labels)
        self._blank_id = blank_id
        # The native extension expects an int flag rather than a bool.
        self._log_probs = 1 if log_probs_input else 0
        if model_path:
            self._scorer = ctc_decode.paddle_get_scorer(
                alpha, beta, model_path.encode(), self._labels, self._num_labels
            )
        self._cutoff_prob = cutoff_prob

    def decode(self, probs, states, is_eos_s, seq_lens=None):
        """
        Conducts the beam search on model outputs and returns results.

        Args:
            probs (Tensor): A rank 3 tensor representing model outputs.
                Shape is batchsize x num_timesteps x num_labels.
            states (Sequence[DecoderState]): sequence of decoding states with
                length equal to batch size.
            is_eos_s (Sequence[bool]): sequence of bools with length equal to batch size.
                Use False while more chunks will follow for that item, and True when
                pushing the last chunk to obtain the final answer.
            seq_lens (Tensor): A rank 1 tensor representing the sequence length of the
                items in the batch. Optional; if not provided the size of axis 1
                (num_timesteps) of `probs` is used for all items.

        Returns:
            tuple: (beam_results, beam_scores, timesteps, out_lens)

            beam_results (Tensor): A 3-dim tensor representing the top n beams of a
                batch of items. Shape: batchsize x num_beams x num_timesteps.
                Results are still encoded as ints at this stage.
            beam_scores (Tensor): A 3-dim tensor representing the likelihood of each
                beam in beam_results. Shape: batchsize x num_beams x num_timesteps.
            timesteps (Tensor): A 2-dim tensor representing the timesteps at which
                the nth output character has peak probability. To be used as
                alignment between audio and transcript. Shape: batchsize x num_beams.
            out_lens (Tensor): A 2-dim tensor representing the length of each beam
                in beam_results. Shape: batchsize x n_beams.
        """
        # The native decoder works on CPU float tensors.
        probs = probs.cpu().float()
        batch_size, max_seq_len = probs.size(0), probs.size(1)
        if seq_lens is None:
            seq_lens = torch.IntTensor(batch_size).fill_(max_seq_len)
        else:
            seq_lens = seq_lens.cpu().int()
        # Output buffers filled in-place by the native call.
        scores = torch.FloatTensor(batch_size, self._beam_width).cpu().float()
        out_seq_len = torch.zeros(batch_size, self._beam_width).cpu().int()

        decode_fn = ctc_decode.paddle_beam_decode_with_given_state
        res_beam_results, res_timesteps = decode_fn(
            probs,
            seq_lens,
            self._num_processes,
            [state.state for state in states],
            is_eos_s,
            scores,
            out_seq_len,
        )
        res_beam_results = res_beam_results.int()
        res_timesteps = res_timesteps.int()

        return res_beam_results, scores, res_timesteps, out_seq_len

    def character_based(self):
        """Return whether the loaded LM is character based, or None without a scorer."""
        return ctc_decode.is_character_based(self._scorer) if self._scorer else None

    def max_order(self):
        """Return the n-gram order of the loaded LM, or None without a scorer."""
        return ctc_decode.get_max_order(self._scorer) if self._scorer else None

    def dict_size(self):
        """Return the dictionary size of the loaded LM, or None without a scorer."""
        return ctc_decode.get_dict_size(self._scorer) if self._scorer else None

    @staticmethod
    def reset_state(state):
        """Release a native decoder state.

        NOTE(review): the original declared this as ``def reset_state(state)``
        without ``self``, so calling it on an instance bound the instance to
        ``state`` and failed. ``@staticmethod`` fixes instance calls while
        keeping ``OnlineCTCBeamDecoder.reset_state(state)`` working unchanged.
        """
        ctc_decode.paddle_release_state(state)
251+
252+
class DecoderState:
    """
    Holds the native decoder state for one audio stream, so successive calls to
    OnlineCTCBeamDecoder.decode continue a single beam search across chunks.

    Note: do not reuse a state after pushing the final (is_eos) chunk; create a
    fresh DecoderState for each new stream.

    Args:
        decoder (OnlineCTCBeamDecoder): the decoder this state will be used
            with; its configuration (labels, beam width, pruning settings,
            blank id, scorer) is captured into the native state.
    """

    def __init__(self, decoder):
        self.state = ctc_decode.paddle_get_decoder_state(
            decoder._labels,
            decoder._beam_width,
            decoder._cutoff_prob,
            decoder._cutoff_top_n,
            decoder._blank_id,
            decoder._log_probs,
            decoder._scorer,
        )

    def __del__(self):
        # getattr guard: __del__ can run even when __init__ raised before
        # self.state was assigned, and must never raise during teardown.
        state = getattr(self, "state", None)
        if state is not None:
            ctc_decode.paddle_release_state(state)
0 commit comments