Commit aaa5091

Changed [Frame | FeatureBuffer] to [Request]
Signed-off-by: arushid <arushid@nvidia.com>
1 parent d040e6b commit aaa5091
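
For context: the commit narrows the pipeline APIs from the union type `Frame | FeatureBuffer` to a single `Request` type imported from the same `framing/request` module. The repository's actual definition of `Request` is not shown in this diff; the sketch below is one plausible shape, assuming `Request` is a small base class carrying the only two fields the changed code reads (`stream_id` and `is_last`). The payload fields on `Frame` and `FeatureBuffer` are illustrative, not taken from the source.

```python
# Hypothetical sketch only -- not the actual NeMo source for
# nemo/collections/asr/inference/streaming/framing/request.py.
from dataclasses import dataclass, field

import torch


@dataclass
class Request:
    """One unit of streaming work: a chunk belonging to a single stream."""

    stream_id: int  # which stream this chunk belongs to
    is_last: bool  # True on the final chunk of the stream (end of stream)


@dataclass
class Frame(Request):
    """A chunk of raw audio samples."""

    samples: torch.Tensor = field(default_factory=lambda: torch.empty(0))


@dataclass
class FeatureBuffer(Request):
    """A chunk of precomputed acoustic features (e.g. log-mel)."""

    features: torch.Tensor = field(default_factory=lambda: torch.empty(0))
```

With a shape like this, any mixed list of frames and feature buffers is a valid `list[Request]`, which is what lets the annotations in the diffs below collapse to a single type.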

2 files changed: +36, -35 lines changed

nemo/collections/asr/inference/pipelines/cache_aware_ctc_pipeline.py

Lines changed: 19 additions & 17 deletions

```diff
@@ -29,7 +29,7 @@
 from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_ctc_decoder import CTCGreedyDecoder
 from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_ctc_endpointing import CTCGreedyEndpointing
 from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer
-from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame
+from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request
 from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions
 from nemo.collections.asr.inference.streaming.state.cache_aware_ctc_state import CacheAwareCTCStreamingState
 from nemo.collections.asr.inference.utils.endpointing_utils import millisecond_to_frames
@@ -214,17 +214,19 @@ def preprocess(self, buffers: list[Tensor], right_paddings: list[int] | None = N
         feature_buffers = torch.cat(feature_buffers).to(self.device)
         return feature_buffers, feature_buffer_lens
 
-    def run_greedy_decoder(self, state: CacheAwareCTCStreamingState, frame: Frame | FeatureBuffer, log_probs: Tensor):
+    def run_greedy_decoder(
+        self, state: CacheAwareCTCStreamingState, request: Request, log_probs: Tensor
+    ):
         """
         Run the greedy CTC decoder on the log_probs and update the state
         Args:
             state: (CacheAwareCTCStreamingState) The state of the stream
-            frame: (Frame | FeatureBuffer) The current frame or feature buffer
-            log_probs: (Tensor) The log probabilities of the current frame
+            request: (Request) The current request (frame or feature buffer)
+            log_probs: (Tensor) The log probabilities of the current request
         Returns:
             (bool) Whether EOU is detected.
         """
-        eou_detected = frame.is_last
+        eou_detected = request.is_last
         last_token = state.label_buffer[-1] if len(state.label_buffer) > 0 else self.blank_id
         cur_output = self.greedy_ctc_decoder(log_probs, compute_confidence=True, previous=last_token)
         state.update_label_buffer(cur_output["labels"])
@@ -242,28 +244,28 @@ def run_greedy_decoder(self, state: CacheAwareCTCStreamingState, frame: Frame |
 
     def decode_log_probs(
         self,
-        frames: list[Frame | FeatureBuffer],
+        requests: list[Request],
         log_probs: Tensor,
         tail_log_probs: Tensor | None,
         ready_state_ids: set,
     ) -> None:
         """
         Decode the log probabilities and update the state
         Args:
-            frames: (list[Frame | FeatureBuffer]) List of frames or feature buffers to transcribe.
+            requests: (list[Request]) List of requests (frames or feature buffers) to transcribe.
             log_probs: (Tensor) Log probabilities.
             tail_log_probs: (Tensor | None) Tail log probabilities.
             ready_state_ids: (set) Set of ready state IDs.
         """
 
-        for idx, frame in enumerate(frames):
-            state = self.get_state(frame.stream_id)
-            eou_detected = self.run_greedy_decoder(state, frame, log_probs[idx])
+        for idx, request in enumerate(requests):
+            state = self.get_state(request.stream_id)
+            eou_detected = self.run_greedy_decoder(state, request, log_probs[idx])
 
             if eou_detected:
                 self.bpe_decoder.decode_bpe_tokens(state)
                 state.cleanup_after_eou()
-                ready_state_ids.add(frame.stream_id)
+                ready_state_ids.add(request.stream_id)
 
             if tail_log_probs is not None:
                 last_token = state.label_buffer[-1] if len(state.label_buffer) > 0 else self.blank_id
@@ -274,15 +276,15 @@ def decode_log_probs(
 
     def cache_aware_transcribe_step(
         self,
-        frames: list[Frame | FeatureBuffer],
+        requests: list[Request],
         buffered_features: list[Tensor],
         right_paddings: list[int] | None,
        ready_state_ids: set,
         keep_all_outputs: bool = False,
     ) -> None:
         """
         Cache Aware Transcribe Step
-        It receives a list of frames (Frame or FeatureBuffer) and features and do the following:
+        It receives a list of requests (Frame or FeatureBuffer) and features and do the following:
 
         1. Preprocess the features by stacking them and computing the lengths
         2. Get the context and mapping from the context manager for cache aware streaming
@@ -291,16 +293,16 @@ def cache_aware_transcribe_step(
         5. Decode the log probabilities and update the state
 
         Args:
-            frames: (list[Frame | FeatureBuffer]) List of frames or feature buffers to transcribe.
+            requests: (list[Request]) List of requests (frames or feature buffers) to transcribe.
             buffered_features: (list[Tensor]) List of buffered features.
             right_paddings: (list[int] | None) List of right paddings.
             ready_state_ids: (set) Set of ready state IDs.
             keep_all_outputs: (bool) Whether to keep all outputs or not.
         """
         feature_buffers, feature_buffer_lens = self.preprocess(buffered_features, right_paddings)
 
-        stream_ids = [frame.stream_id for frame in frames]
-        eos_flags = [frame.is_last for frame in frames]
+        stream_ids = [request.stream_id for request in requests]
+        eos_flags = [request.is_last for request in requests]
         context, mapping = self.context_manager.get_context(stream_ids)
 
         drop_extra_pre_encoded = 0 if not self.use_cache else self.asr_model.drop_extra_pre_encoded
@@ -319,7 +321,7 @@ def cache_aware_transcribe_step(
         log_probs = normalize_log_probs(log_probs)
         self.context_manager.update_cache(stream_ids, new_context, mapping)
         self.context_manager.reset_slots(stream_ids, eos_flags)
-        self.decode_log_probs(frames, log_probs, tail_log_probs, ready_state_ids)
+        self.decode_log_probs(requests, log_probs, tail_log_probs, ready_state_ids)
 
     def transcribe_step_for_frames(self, frames: list[Frame]) -> None:
         """
```

nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py

Lines changed: 17 additions & 18 deletions

```diff
@@ -29,7 +29,7 @@
 from nemo.collections.asr.inference.streaming.decoders.greedy.greedy_rnnt_decoder import RNNTGreedyDecoder
 from nemo.collections.asr.inference.streaming.endpointing.greedy.greedy_rnnt_endpointing import RNNTGreedyEndpointing
 from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer
-from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame
+from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request
 from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions
 from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import CacheAwareRNNTStreamingState
 from nemo.collections.asr.inference.utils.endpointing_utils import millisecond_to_frames
@@ -231,18 +231,18 @@ def preprocess(self, buffers: list[Tensor], right_paddings: list[int] | None = N
         return feature_buffers, feature_buffer_lens
 
     def run_greedy_decoder(
-        self, state: CacheAwareRNNTStreamingState, frame: Frame | FeatureBuffer, hyp: Hypothesis
+        self, state: CacheAwareRNNTStreamingState, request: Request, hyp: Hypothesis
     ) -> bool:
         """
         Run the greedy RNNT decoder on the hypothesis and update the state
         Args:
             state: (CacheAwareRNNTStreamingState) The state of the stream
-            frame: (Frame | FeatureBuffer) The current frame or feature buffer
-            hyp: (Hypothesis) The hypothesis of the current frame
+            request: (Request) The current request (frame or feature buffer)
+            hyp: (Hypothesis) The hypothesis of the current request
         Returns:
             (bool) Whether EOU is detected.
         """
-        eou_detected = frame.is_last
+        eou_detected = request.is_last
         cur_output, cur_labels, new_offset = self.greedy_rnnt_decoder(
             global_timestamps=hyp.timestamp,
             tokens=hyp.y_sequence,
@@ -266,15 +266,15 @@ def run_greedy_decoder(
 
     def cache_aware_transcribe_step(
         self,
-        frames: list[Frame | FeatureBuffer],
+        requests: list[Request],
         features: list[Tensor],
         right_paddings: list[int],
         ready_state_ids: set,
         keep_all_outputs: bool = False,
     ) -> None:
         """
         Cache Aware Transcribe Step
-        It receives a list of frames (Frame or FeatureBuffer) and features and do the following:
+        It receives a list of requests (Frame or FeatureBuffer) and features and do the following:
 
         1. Preprocess the features by stacking them and computing the lengths
         2. Collecting previous hypotheses for stateful decoding
@@ -285,7 +285,7 @@ def cache_aware_transcribe_step(
         7. Perform greedy RNNT decoding to get the best hypothesis and update the states
         8. Update the ready states to indicate that the state is ready for text post-processing
         Args:
-            frames: (list[Frame | FeatureBuffer]) List of frames or feature buffers to transcribe.
+            requests: (list[Request]) List of requests (frames or feature buffers) to transcribe.
             features: (list[Tensor]) List of feature buffers.
             right_paddings: (list[int] | None) List of right paddings.
             ready_state_ids: (set) Set of ready state IDs.
@@ -294,10 +294,10 @@ def cache_aware_transcribe_step(
 
         feature_buffers, feature_buffer_lens = self.preprocess(features, right_paddings)
         states, stream_ids, eos_flags = [], [], []
-        for frame in frames:
-            states.append(self.get_state(frame.stream_id))
-            stream_ids.append(frame.stream_id)
-            eos_flags.append(frame.is_last)
+        for request in requests:
+            states.append(self.get_state(request.stream_id))
+            stream_ids.append(request.stream_id)
+            eos_flags.append(request.is_last)
 
         previous_hypotheses = [state.get_previous_hypothesis() for state in states]
         context, mapping = self.context_manager.get_context(stream_ids)
@@ -324,20 +324,19 @@ def cache_aware_transcribe_step(
         self.context_manager.reset_slots(stream_ids, eos_flags)
 
         # update the previous hypothesis and reset the previous hypothesis for the streams that has ended
-        for i, (state, hyp, eos) in enumerate(zip(states, best_hyp, eos_flags)):
-            hyp_len = len(hyp.y_sequence) if hyp is not None and hasattr(hyp, 'y_sequence') else 0
+        for state, hyp, eos in zip(states, best_hyp, eos_flags):
             if eos:
                 state.reset_previous_hypothesis()
             else:
                 state.set_previous_hypothesis(hyp)
 
-        # run greedy decoder for each frame-state-hypothesis tuple
-        for frame, state, hyp in zip(frames, states, best_hyp):
-            eou_detected = self.run_greedy_decoder(state, frame, hyp)
+        # run greedy decoder for each request-state-hypothesis tuple
+        for request, state, hyp in zip(requests, states, best_hyp):
+            eou_detected = self.run_greedy_decoder(state, request, hyp)
             if eou_detected:
                 self.bpe_decoder.decode_bpe_tokens(state)
                 state.cleanup_after_eou()
-                ready_state_ids.add(frame.stream_id)
+                ready_state_ids.add(request.stream_id)
 
     def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None:
         """
```
