Commit e503a6e

Authored by nune-tadevosyan, monica-sekoyan, nithinraok, and chtruong814
Initial Chunking (#14321)
Squashed commit history (repeated "Apply isort and black reformatting" commits and per-commit sign-off trailers collapsed):

* adding nfa to canary; remove comments; modify external model loading
* fix audio padding; handle non-possible alignment; add offset refinement
* Initial Chunking; adding comments and docstrings; changes in docstrings
* updates to the algorithm; update with timestamps; remove join_text; remove pdb
* adjust timestamps; support for long audio; refactoring to keep model clean
* removing changes from mixin; back to main for mixin; fix for hypotheses (reverted and re-applied)
* allowing user to control chunking; doc changes; forcing true for chunking
* handle merge case for timestamps; add timestamp_type; add timestamps support for chunked inference
* refactor CTC timestamps to use utils; correct restore_token_cased with unk_token
* use timestamp utils in rnnt_decoding; change external timestamps ASR model loading
* add forced-alignment method tests; modify nfa to match new setup and utils; update nfa docs
* remove unused imports and errors; address review comments, linting and flake8
* handle decode_ids_to_str change; correct usage of decode_tokens_to_str
* revert jupyter settings; merge and tests; unit tests; tests fix
* doc updates; doc change for speech_to_text_aed_chunked_infer; copyright
* correct description; make private; rewrite restore_timestamps_asr_model
* update timestamps; fix word offset logic; tests update after the fix; cases for monotonicity
* Increase L0_Unit_Tests_GPU_ASR timeout to 30

Signed-off-by: Monica Sekoyan <msekoyan@nvidia.com>
Signed-off-by: Nune <ntadevosyan@nvidia.com>
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Signed-off-by: Charlie Truong <chtruong@nvidia.com>
Co-authored-by: Monica Sekoyan <msekoyan@nvidia.com>
Co-authored-by: monica-sekoyan <monica-sekoyan@users.noreply.github.com>
Co-authored-by: nune-tadevosyan <nune-tadevosyan@users.noreply.github.com>
Co-authored-by: nithinraok <nithinrao.koluguri@gmail.com>
Co-authored-by: Charlie Truong <chtruong@nvidia.com>
1 parent 454fabc commit e503a6e

File tree: 11 files changed (+889 −40 lines)

.github/workflows/cicd-main-speech.yml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ jobs:
       include:
         - script: L0_Unit_Tests_GPU_ASR
           runner: self-hosted-azure-gpus-1
-          timeout: 20
+          timeout: 30
         - script: L0_Unit_Tests_CPU_ASR
           runner: self-hosted-azure-cpu
           cpu-only: true

examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py

Lines changed: 9 additions & 6 deletions
@@ -17,17 +17,16 @@
 seconds and performs inference on each
 segment individually. The results are then concatenated to form the final output.
 
-Below is an example of how to run this script with the Canary-1b model.
+Below is an example of how to run this script with the Canary-1b-v2 model.
 It's recommended to use manifest input, otherwise the model will perform English ASR
 with punctuations and capitalizations.
 An example manifest line:
 {
     "audio_filepath": "/path/to/audio.wav",  # path to the audio file
     "duration": 10000.0,  # duration of the audio
     "taskname": "asr",  # use "s2t_translation" for AST
-    "source_lang": "en",  # Set `source_lang`==`target_lang` for ASR, choices=['en','de','es','fr']
-    "target_lang": "de",  # choices=['en','de','es','fr']
-    "pnc": "yes",  # whether to have PnC output, choices=['yes', 'no']
+    "source_lang": "en",  # Set `source_lang`==`target_lang` for ASR. Currently supported for 25 EU languages.
+    "target_lang": "de",  # See https://huggingface.co/nvidia/canary-1b-v2
 }
 
 Example Usage:
@@ -41,8 +40,12 @@
     batch_size=16 \
     decoding.beam.beam_size=1
 
-To return word and segment level timestamps, add `timestamps=True` to the above command,
-and set `chunk_len_in_secs=10.0` for best results.
+To return word and segment level timestamps, add `timestamps=True` to the above command.
+
+Note: Canary-1b-v2 supports long-form inference via the `.transcribe()` method.
+It will use dynamic chunking with overlapping windows for better performance.
+This behavior is enabled automatically for long-form inference when transcribing a single
+audio file or when batch_size is set to 1.
 
 """
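The docstring above describes the script's fixed-length strategy: cut the audio into consecutive segments, run inference on each, and concatenate the results. A minimal sketch of computing such segment boundaries (the function name and the `chunk_len_in_secs` default are illustrative, not the script's actual internals):

```python
def chunk_boundaries(duration_s: float, chunk_len_in_secs: float = 40.0):
    """Return (start_s, end_s) windows covering [0, duration_s] back to back."""
    boundaries = []
    start = 0.0
    while start < duration_s:
        end = min(start + chunk_len_in_secs, duration_s)
        boundaries.append((start, end))
        start = end
    return boundaries

# A 100 s file with 40 s chunks yields two full windows and one 20 s remainder.
print(chunk_boundaries(100.0, 40.0))  # [(0.0, 40.0), (40.0, 80.0), (80.0, 100.0)]
```

The per-segment transcripts are then joined in order to form the final output.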

nemo/collections/asr/data/audio_to_text_lhotse_prompted.py

Lines changed: 110 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Callable, Optional, Union
+from typing import Optional, Union
 
 import torch.utils.data
 from lhotse import CutSet
@@ -21,7 +21,7 @@
 from lhotse.dataset.collation import collate_vectors
 
 from nemo.collections.common.data import apply_prompt_format_fn
-from nemo.collections.common.prompts import CanaryPromptFormatter, PromptFormatter
+from nemo.collections.common.prompts import PromptFormatter
 from nemo.collections.common.tokenizers import TokenizerSpec
 
 
@@ -61,22 +61,43 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
     Tokenized utterances will be extended with special prompt tokens according to ``prompt_format_fn`` logic.
     We support cuts with multiple supervision segments -- their tokenized texts will be concatenated before we add the prompt tokens.
     This is useful, for example, in code-switched scenarios where each segment is spoken in a different language.
+
+    Chunking:
+        If `enable_chunking` is True, each audio sample is split into optimally sized chunks
+        (see `find_optimal_chunk_size` and `chunk_waveform`). This is useful for long audio inputs,
+        allowing the model to process them in manageable segments.
     """
 
     def __init__(
         self,
         tokenizer: TokenizerSpec,
         prompt: PromptFormatter,
+        enable_chunking: bool = False,
     ):
         super().__init__()
         self.tokenizer = tokenizer
         self.load_audio = AudioSamples(fault_tolerant=True)
         self.padding_value = self.tokenizer.pad_id
         self.prompt = prompt
+        self.enable_chunking = enable_chunking
 
     def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
         audio, audio_lens, cuts = self.load_audio(cuts)
 
+        # Will work if batch_size is set to 1.
+        if self.enable_chunking:
+            # If dynamic chunking is enabled, split each audio sample into chunks.
+            new_audio = []
+            new_audio_lens = []
+            for i in range(audio.shape[0]):
+                waveform = audio[i, : audio_lens[i]]
+                # Split the waveform into chunks and get their lengths.
+                chunks, chunk_lens = self._chunk_waveform(waveform)
+                new_audio.extend(chunks)
+                new_audio_lens.extend(chunk_lens)
+            # Stack all chunks into a batch.
+            audio = torch.stack(new_audio)
+            audio_lens = torch.tensor(new_audio_lens, dtype=torch.long)
         # Fast-path: the tokenization and prompt formatting was already done before sampling.
         attrs = ("input_ids", "context_ids", "answer_ids")
         pre_formatted = all(hasattr(c, a) for c in cuts for a in attrs)
@@ -110,6 +131,93 @@ def _collate_tokens(self, tokens: list[Union[list[int], torch.Tensor]]) -> tuple
         tokens = collate_vectors(tokens, padding_value=self.padding_value)
         return tokens, token_lens
 
+    def _find_optimal_chunk_size(
+        self, total_len: int, min_sec: int = 30, max_sec: int = 40, sample_rate: int = 16000, overlap_sec: float = 1.0
+    ) -> int:
+        """
+        Find the optimal chunk size for audio processing that minimizes padding of the last chunk.
+
+        Args:
+            total_len (int): Total length of the audio waveform in samples
+            min_sec (int, optional): Minimum chunk size in seconds. Defaults to 30.
+            max_sec (int, optional): Maximum chunk size in seconds. Defaults to 40.
+            sample_rate (int, optional): Audio sample rate in Hz. Defaults to 16000.
+            overlap_sec (float, optional): Overlap duration between consecutive chunks in seconds.
+                Defaults to 1.0.
+
+        Returns:
+            int: Optimal chunk size in samples that maximizes the last chunk length
+        """
+        best_chunk_size = min_sec * sample_rate
+        best_last_chunk_len = 0
+        if total_len < max_sec * sample_rate:
+            return total_len
+        # Try each possible chunk duration in the range
+        for sec in range(min_sec, max_sec + 1):
+            chunk_size = sec * sample_rate
+            overlap_size = int(overlap_sec * sample_rate)
+            step_size = chunk_size - overlap_size
+
+            if step_size <= 0:  # Invalid overlap
+                continue
+            if chunk_size > total_len:
+                continue
+
+            # Calculate how many chunks we'd need and the last chunk's length
+            n_chunks = (total_len + step_size - 1) // step_size
+            last_chunk_len = total_len - step_size * (n_chunks - 1)
+
+            if last_chunk_len > best_last_chunk_len:
+                best_last_chunk_len = last_chunk_len
+                best_chunk_size = chunk_size
+
+        return best_chunk_size
+
+    def _chunk_waveform(
+        self, waveform: torch.Tensor, chunk_size: int = None, overlap_sec: float = 1.0, sample_rate: int = 16000
+    ) -> tuple[list[torch.Tensor], list[int]]:
+        """
+        Split a waveform tensor into overlapping chunks.
+
+        Args:
+            waveform (torch.Tensor): Input audio waveform tensor of shape (time_samples,)
+            chunk_size (int, optional): Size of each chunk in samples. If None, automatically
+                determines the optimal chunk size using find_optimal_chunk_size().
+                Defaults to None.
+            overlap_sec (float, optional): Overlap duration between consecutive chunks in seconds.
+                Used to calculate step size. Defaults to 1.0.
+            sample_rate (int, optional): Audio sample rate in Hz. Defaults to 16000.
+
+        Returns:
+            tuple[list[torch.Tensor], list[int]]: A tuple containing:
+                - List of chunk tensors, each of shape (chunk_size,)
+                - List of original lengths for each chunk before padding (useful for masking
+                  padded regions during processing)
+        """
+        # If chunk_size is None, find the optimal chunk size for this waveform
+        total_len = waveform.shape[0]
+        if chunk_size is None:
+            chunk_size = self._find_optimal_chunk_size(total_len, overlap_sec=overlap_sec)
+        if chunk_size >= total_len:
+            return [waveform], [total_len]
+        overlap_size = int(overlap_sec * sample_rate)
+        step_size = chunk_size - overlap_size
+        chunks = []
+        chunk_lens = []
+        start = 0
+        while start + overlap_size < total_len:
+            end = min(start + chunk_size, total_len)
+            chunk = waveform[start:end]
+            length = chunk.shape[0]
+            if length < chunk_size:
+                pad = torch.zeros(chunk_size - length, dtype=chunk.dtype, device=chunk.device)
+                chunk = torch.cat([chunk, pad], dim=0)
+            chunks.append(chunk)
+            chunk_lens.append(length)
+            start += step_size
+
+        return chunks, chunk_lens
+
 
 class ProbablyIncorrectLanguageKeyError(RuntimeError):
     pass
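The two helpers added above search for a chunk duration between `min_sec` and `max_sec` that leaves the last chunk as full as possible, then split the waveform with a one-second overlap. A standalone sketch of the same arithmetic on plain Python lists (parameter names mirror the diff; the single-chunk early return is assumed behavior, and tensors are replaced by lists for illustration):

```python
def find_optimal_chunk_size(total_len, min_sec=30, max_sec=40, sample_rate=16000, overlap_sec=1.0):
    # Short audio: one chunk covering everything.
    if total_len < max_sec * sample_rate:
        return total_len
    best_chunk_size = min_sec * sample_rate
    best_last_chunk_len = 0
    overlap = int(overlap_sec * sample_rate)
    for sec in range(min_sec, max_sec + 1):
        chunk_size = sec * sample_rate
        step = chunk_size - overlap
        if step <= 0 or chunk_size > total_len:
            continue
        n_chunks = (total_len + step - 1) // step           # ceil division
        last_chunk_len = total_len - step * (n_chunks - 1)  # samples left for the final chunk
        if last_chunk_len > best_last_chunk_len:
            best_last_chunk_len = last_chunk_len
            best_chunk_size = chunk_size
    return best_chunk_size


def chunk_waveform(samples, sample_rate=16000, overlap_sec=1.0):
    total_len = len(samples)
    chunk_size = find_optimal_chunk_size(total_len, sample_rate=sample_rate, overlap_sec=overlap_sec)
    if chunk_size >= total_len:  # whole waveform fits in one chunk
        return [list(samples)], [total_len]
    overlap = int(overlap_sec * sample_rate)
    step = chunk_size - overlap
    chunks, chunk_lens = [], []
    start = 0
    while start + overlap < total_len:
        piece = list(samples[start:start + chunk_size])
        chunk_lens.append(len(piece))               # true length before padding
        piece += [0] * (chunk_size - len(piece))    # zero-pad the last chunk
        chunks.append(piece)
        start += step
    return chunks, chunk_lens
```

For a 100 s waveform at a toy 100 Hz rate, the search settles on 35 s chunks: three chunks of 3500 samples, with the last one holding 3200 real samples before padding. Consecutive chunks share a one-second overlap, which the merge step later uses to stitch transcripts together.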

nemo/collections/asr/models/aed_multitask_models.py

Lines changed: 62 additions & 7 deletions
@@ -40,6 +40,7 @@
 from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType
 from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig
 from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier
+from nemo.collections.asr.parts.utils.chunking_utils import merge_all_hypotheses, merge_parallel_chunks
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
 from nemo.collections.asr.parts.utils.timestamp_utils import (
     get_forced_aligned_timestamps_with_external_model,
@@ -110,6 +111,10 @@ class MultiTaskTranscriptionInternalConfig(InternalTranscribeConfig):
 class MultiTaskTranscriptionConfig(TranscribeConfig):
     """
     Configuration for Multi Task Transcription
+
+    enable_chunking: bool = True
+        Whether to enable parallel processing of audio chunks for long-form audio.
+        If enabled, batch_size should be set to 1 or a single audio file should be passed.
     """
 
     prompt: list[dict[str, dict[str, str]]] | None = None
@@ -119,6 +124,7 @@ class MultiTaskTranscriptionConfig(TranscribeConfig):
     _internal: Optional[MultiTaskTranscriptionInternalConfig] = field(
         default_factory=lambda: MultiTaskTranscriptionInternalConfig()
     )
+    enable_chunking: bool = True
 
     def __post_init__(self):
         self.prompt = parse_multitask_prompt(self.prompt)
@@ -495,6 +501,7 @@ def transcribe(
     ) -> Union[List[str], List[Hypothesis]]:
         """
         Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
+        This allows the model to process long audio in manageable chunks and merge the results.
         Args:
             audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path
                 to a manifest file.
@@ -525,7 +532,7 @@
 
         Returns:
             A list of transcriptions (or raw log probabilities if logprobs is True) in the same order
-            as paths2audio_files
+                as paths2audio_files
         """
         if timestamps is not None:
             if self.timestamps_asr_model is None:
@@ -561,22 +568,43 @@
             trcfg = override_config
         trcfg.timestamps = timestamps
 
-        return super().transcribe(audio=audio, override_config=trcfg)
+        if trcfg.enable_chunking:
+            # Check if only one audio is provided with string
+            is_one_audio = isinstance(audio, str) and not (audio.endswith("json") or audio.endswith("jsonl"))
+            # Check if it is provided as a list of strings
+            is_one_audio = is_one_audio or (isinstance(audio, list) and len(audio) == 1)
+            # Check if chunking will be enabled
+            trcfg.enable_chunking = is_one_audio or (override_config is not None and override_config.batch_size == 1)
+            if not trcfg.enable_chunking:
+                logging.warning("Chunking is disabled. Please pass a single audio file or set batch_size to 1")
+
+        results = super().transcribe(audio=audio, override_config=trcfg)
+        if trcfg.enable_chunking:
+            results = merge_all_hypotheses(results, trcfg.timestamps, self.encoder.subsampling_factor)
+
+        return results
 
     def _setup_dataloader_from_config(self, config: Optional[Dict]):
+
         assert config.get("use_lhotse", False), (
             "Multi-task model only supports dataloading with Lhotse. "
             "Please set config.{train,validation,test}_ds.use_lhotse=True"
         )
         global_rank = config.get("global_rank", self.global_rank)
        world_size = config.get("world_size", self.world_size)
+        enable_chunking = config.get("enable_chunking", False)
+        if enable_chunking:
+            # Adding this to support processing audio files of arbitrary length by chunking them into hour-long segments.
+            config.cut_into_windows_duration = 3600
+            config.cut_into_windows_hop = 3600
         return get_lhotse_dataloader_from_config(
             config,
             global_rank=global_rank,
             world_size=world_size,
             dataset=PromptedAudioToTextLhotseDataset(
                 tokenizer=self.tokenizer,
                 prompt=self.prompt,
+                enable_chunking=enable_chunking,  # <-- enables chunking
             ),
             tokenizer=self.tokenizer,
         )
@@ -889,10 +917,12 @@ def _transcribe_input_manifest_processing(
             A config dict that is used to setup the dataloader for transcription.
         """
         manifest_filepath = trcfg._internal.manifest_filepath
-
         audio_files = self._may_be_make_dict_and_fix_paths(audio_files, manifest_filepath, trcfg)
 
-        return super()._transcribe_input_manifest_processing(audio_files, temp_dir, trcfg)
+        ds_config = super()._transcribe_input_manifest_processing(audio_files, temp_dir, trcfg)
+        if trcfg.enable_chunking:
+            ds_config['enable_chunking'] = True
+        return ds_config
 
     def _transcribe_forward(
         self, batch: PromptedAudioToTextMiniBatch | tuple[torch.Tensor, ...], trcfg: MultiTaskTranscriptionConfig
@@ -979,6 +1009,8 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
         """
         Internal function to process the model's outputs to return the results to the user. This function is called by
         `transcribe()` and `transcribe_generator()` to process the model's outputs.
+        If parallel chunking was used (enable_chunking=True), merges the hypotheses from each chunk
+        into a single hypothesis, joining text, token sequences, and timestamps.
 
         Args:
             outputs: The model's outputs that are processed by `_transcribe_forward()`.
@@ -988,6 +1020,7 @@
             The output can be a list of
             objects, list of list of objects.
             Its type is defined in `TranscriptionReturnType`.
+
         """
         log_probs = outputs.pop('log_probs')
         encoded_len = outputs.pop('encoded_lengths')
@@ -996,14 +1029,18 @@
         decoder_input_ids = outputs.pop('decoder_input_ids')
         batch = outputs.pop('batch')
 
-        del log_probs, encoded_len
-
+        del log_probs
+        num_chunks = enc_states.shape[0]
+        # Repeat decoder_input_ids to match the number of chunks
+        if trcfg.enable_chunking and num_chunks > decoder_input_ids.shape[0]:
+            decoder_input_ids = decoder_input_ids.repeat(num_chunks, 1)
         hypotheses = self.decoding.decode_predictions_tensor(
             encoder_hidden_states=enc_states,
             encoder_input_mask=enc_mask,
             decoder_input_ids=decoder_input_ids,
             return_hypotheses=trcfg.return_hypotheses,
         )
+        merge_to_be_done = trcfg.enable_chunking and len(hypotheses) > 1
 
         del enc_states, enc_mask, decoder_input_ids
 
@@ -1013,13 +1050,29 @@
                 batch_size=len(batch.audio),
                 external_ctc_model=self.timestamps_asr_model,
                 main_model_predictions=hypotheses,
-                timestamp_type=['word', 'segment'],
+                timestamp_type='char' if merge_to_be_done else ['word', 'segment'],
                 viterbi_device=trcfg._internal.device,
             )
         elif trcfg.timestamps:
             hypotheses = process_aed_timestamp_outputs(
                 hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
             )
+        if merge_to_be_done:
+            merged_hypotheses = merge_parallel_chunks(
+                hypotheses=hypotheses,
+                encoded_len=encoded_len,
+                model=self,
+                timestamps=trcfg.timestamps,
+                subsampling_factor=self.encoder.subsampling_factor,
+                window_stride=self.cfg['preprocessor']['window_stride'],
+                decoding=self.decoding,
+            )
+            # Inject the id of the cut into the hypothesis, to later be used for separating batches
+            setattr(merged_hypotheses, 'id', batch.cuts[0].id.split("-", 1)[0])
+            return [merged_hypotheses]
+
+        if trcfg.enable_chunking and len(hypotheses) == 1:
+            setattr(hypotheses[0], 'id', batch.cuts[0].id.split("-", 1)[0])
         return hypotheses
 
     def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
@@ -1035,6 +1088,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
             stored.
         Returns:
             A pytorch DataLoader for the given audio file(s).
+
         """
         if 'manifest_filepath' in config:
             manifest_filepath = config['manifest_filepath']
@@ -1059,6 +1113,7 @@
             'channel_selector': config.get('channel_selector', None),
             'pad_min_duration': config.get('pad_min_duration', 1.0),
             'pad_direction': config.get('pad_direction', 'both'),
+            'enable_chunking': config.get('enable_chunking', False),
         }
 
         temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))