Skip to content

Commit 0402d4b

Browse files
committed
add ml normalizer
Signed-off-by: nithinraok <[email protected]>
1 parent e3c3d4d commit 0402d4b

File tree

5 files changed

+25
-18
lines changed

5 files changed

+25
-18
lines changed

nemo_asr/run_eval_ml.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,6 @@
1717

1818
wer_metric = evaluate.load("wer")
1919

20-
def normalize_text(text):
21-
"""Simple text normalization for non english languages"""
22-
if text is None:
23-
return ""
24-
# Remove capitalization
25-
text = text.lower()
26-
27-
# Remove punctuation
28-
text = re.sub(r'[^\w\s]', '', text)
29-
30-
# Remove extra spaces
31-
text = re.sub(r'\s+', ' ', text).strip()
32-
return text
3320

3421
def main(args):
3522
DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
@@ -181,8 +168,8 @@ def download_audio_files(batch):
181168
transcriptions = transcriptions[0]
182169

183170
references = all_data["references"]
184-
references = [normalize_text(ref) for ref in references]
185-
predictions = [normalize_text(pred.text) for pred in transcriptions]
171+
references = [data_utils.ml_normalizer(ref) for ref in references]
172+
predictions = [data_utils.ml_normalizer(pred.text) for pred in transcriptions]
186173

187174
avg_time = total_time / len(all_data["audio_filepaths"])
188175

nemo_asr/run_nemo_ml.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ export PYTHONPATH="..":$PYTHONPATH
99
# Configuration
1010
MODEL_ID="nvidia/canary-1b-flash"
1111

12-
BATCH_SIZE=128
12+
BATCH_SIZE=64
13+
1314
DEVICE_ID=0
1415

1516
# Available datasets and languages

normalizer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .normalizer import EnglishTextNormalizer
1+
from .normalizer import EnglishTextNormalizer, BasicMultilingualTextNormalizer

normalizer/data_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datasets import load_dataset, Audio
2-
from normalizer import EnglishTextNormalizer
2+
from normalizer import EnglishTextNormalizer, BasicMultilingualTextNormalizer
33

44
from .eval_utils import read_manifest, write_manifest
55

@@ -30,6 +30,8 @@ def get_text(sample):
3030

3131
normalizer = EnglishTextNormalizer()
3232

33+
ml_normalizer = BasicMultilingualTextNormalizer()
34+
3335

3436
def normalize(batch):
3537
batch["original_text"] = get_text(batch)

normalizer/normalizer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,23 @@ def __call__(self, s: str):
9292
return s
9393

9494

95+
class BasicMultilingualTextNormalizer:
96+
def __init__(self, remove_diacritics: bool = True):
97+
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
98+
99+
def __call__(self, s: str):
100+
s = s.lower()
101+
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
102+
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
103+
s = self.clean(s).lower()
104+
105+
# Remove punctuations and extra spaces
106+
s = re.sub(r"[^\w\s]", "", s)
107+
s = re.sub(r"\s+", " ", s).strip()
108+
109+
return s
110+
111+
95112
class EnglishNumberNormalizer:
96113
"""
97114
Convert any spelled-out numbers into arabic numbers, while handling:

0 commit comments

Comments
 (0)