Skip to content

Commit 3dfcf67

Browse files
Apply isort and black reformatting
Signed-off-by: nune-tadevosyan <nune-tadevosyan@users.noreply.github.com>
1 parent 0284ab0 commit 3dfcf67

File tree

4 files changed

+206
-194
lines changed

4 files changed

+206
-194
lines changed

nemo/collections/asr/models/aed_multitask_models.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType
4444
from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig
4545
from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier
46+
from nemo.collections.asr.parts.utils.chunking_utils import merge_hypotheses_list, merge_parallel_chunks
4647
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
4748
from nemo.collections.asr.parts.utils.timestamp_utils import (
4849
get_forced_aligned_timestamps_with_external_model,
@@ -70,10 +71,7 @@
7071
)
7172
from nemo.utils import logging, model_utils
7273
from nemo.utils.app_state import AppState
73-
from nemo.collections.asr.parts.utils.chunking_utils import (
74-
merge_parallel_chunks,
75-
merge_hypotheses_list
76-
)
74+
7775
__all__ = ['EncDecMultiTaskModel']
7876

7977

@@ -119,8 +117,8 @@ class MultiTaskTranscriptionConfig(TranscribeConfig):
119117
"""
120118
Configuration for Multi Task Transcription
121119
122-
enable_parallel_chunking: bool = False
123-
Whether to enable parallel processing of audio chunks for long-form audio.
120+
enable_parallel_chunking: bool = False
121+
Whether to enable parallel processing of audio chunks for long-form audio.
124122
It will be automatically enabled for batch size 1.
125123
"""
126124

@@ -131,7 +129,7 @@ class MultiTaskTranscriptionConfig(TranscribeConfig):
131129
_internal: Optional[MultiTaskTranscriptionInternalConfig] = field(
132130
default_factory=lambda: MultiTaskTranscriptionInternalConfig()
133131
)
134-
enable_parallel_chunking: bool = False
132+
enable_parallel_chunking: bool = False
135133

136134
def __post_init__(self):
137135
self.prompt = parse_multitask_prompt(self.prompt)
@@ -573,7 +571,7 @@ def transcribe(
573571
)
574572
trcfg = override_config
575573
trcfg.timestamps = timestamps
576-
# Check if only one audio is provided with string
574+
# Check if only one audio is provided with string
577575
is_one_audio = isinstance(audio, str) and not (audio.endswith("json") or audio.endswith("jsonl"))
578576
# Check if it is provided as a list of strings
579577
is_one_audio = is_one_audio or (isinstance(audio, list) and len(audio) == 1)
@@ -1004,8 +1002,6 @@ def _transcribe_forward(
10041002
batch=batch,
10051003
)
10061004

1007-
1008-
10091005
def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionConfig) -> GenericTranscriptionType:
10101006
"""
10111007
Internal function to process the model's outputs to return the results to the user. This function is called by
@@ -1058,23 +1054,23 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo
10581054
hypotheses = process_aed_timestamp_outputs(
10591055
hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
10601056
)
1061-
1062-
if merge_to_be_done:
1057+
1058+
if merge_to_be_done:
10631059
merged_hypotheses = merge_parallel_chunks(
10641060
hypotheses=hypotheses,
10651061
encoded_len=encoded_len,
10661062
model=self,
10671063
subsampling_factor=self.encoder.subsampling_factor,
10681064
window_stride=self.cfg['preprocessor']['window_stride'],
1069-
tokenizer=self.tokenizer
1065+
tokenizer=self.tokenizer,
10701066
)
1071-
#Inject the id of the cut to hypothese to later be used for separate batches
1067+
# Inject the id of the cut into the hypothesis, to be used later for separate batches
10721068
setattr(merged_hypotheses, 'id', batch.cuts[0].id.split("-", 1)[0])
10731069
return [merged_hypotheses]
1074-
1075-
if trcfg.enable_parallel_chunking and len(hypotheses) == 1:
1076-
setattr(hypotheses[0], 'id', batch.cuts[0].id.split("-", 1)[0])
1077-
1070+
1071+
if trcfg.enable_parallel_chunking and len(hypotheses) == 1:
1072+
setattr(hypotheses[0], 'id', batch.cuts[0].id.split("-", 1)[0])
1073+
10781074
return hypotheses
10791075

10801076
def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':

nemo/collections/asr/parts/mixins/transcription.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,6 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig
382382
"""
383383
Transcribe Execution Flow
384384
"""
385-
386385

387386
def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig):
388387
"""

nemo/collections/asr/parts/utils/aligner_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def restore_token_case(word: str, word_tokens: List[str]) -> List[str]:
113113
while "__" in word:
114114
word = word.replace("__", "_")
115115

116-
# while " " in word:
117-
# word = word.replace(" ", "")
116+
# while " " in word:
117+
# word = word.replace(" ", "")
118118

119119
word_tokens_cased = []
120120
word_char_pointer = 0
@@ -135,7 +135,11 @@ def restore_token_case(word: str, word_tokens: List[str]) -> List[str]:
135135
word_char_pointer += 1
136136
else:
137137
if token_char == "▁" or token_char == "_":
138-
if word[word_char_pointer] == "▁" or word[word_char_pointer] == "_" or word[word_char_pointer] == " ":
138+
if (
139+
word[word_char_pointer] == "▁"
140+
or word[word_char_pointer] == "_"
141+
or word[word_char_pointer] == " "
142+
):
139143
token_cased += token_char
140144
word_char_pointer += 1
141145
elif word_char_pointer == 0:

0 commit comments

Comments (0)