Skip to content

Commit 0c592a3

Browse files
authored
Merge pull request #6 from linagora-labs/monitoring (fix/improvements for nemo)
Various fixes and improvements and fixes mostly for nemo
2 parents 47f5eb1 + 231ee68 commit 0c592a3

23 files changed

+221
-493
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# See https://pre-commit.com for more information
22
# See https://pre-commit.com/hooks.html for more hooks
33
default_language_version:
4-
python: python3.10
4+
python: python3
55
repos:
66
- repo: https://github.com/charliermarsh/ruff-pre-commit
77
rev: v0.1.11

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ This repository focus on the following features:
3838
├── tools/ : Scripts to cope with audio data (data curation, ...)
3939
│   ├── kaldi/utils/ : Scripts to check and complete kaldi's data folders (.sh and .pl scripts)
4040
│   ├── LeVoiceLab/ : Scripts to convert data from/to LeVoiceLab format (see https://speech-data-hub.levoicelab.org/)
41+
│   ├── nemo/ : Scripts to manipulate, prepare and convert data to NeMo format
4142
│   └── scraping/ : Scripts to scrape a collection of documents (docx, pdf...) or the web
4243
├── docker/ : Docker environment
4344
└── tests/ : Unittest suite

ssak/utils/align_transcriptions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
load_model,
1414
)
1515
from ssak.utils.misc import hashmd5
16-
from ssak.utils.text_basic import transliterate
17-
from ssak.utils.text_basic import _punctuation
16+
from ssak.utils.text_basic import _punctuation, transliterate
1817
from ssak.utils.viewer import PlayWav
1918

2019
imshow_opts = dict(origin="upper", aspect="auto", vmax=0) # vmin = -25,

ssak/utils/audio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ def load_audio(path, start=None, end=None, sample_rate=16_000, mono=True, return
7575
audio, sr = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
7676
else:
7777
audio, sr = torchaudio.load(path)
78-
if return_format=="librosa":
78+
if return_format == "librosa":
7979
import librosa
80+
8081
offset = float(start if start else 0)
8182
duration = None
8283
if end:

ssak/utils/kaldi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def parse_line(line):
5858
}
5959

6060

61-
def check_kaldi_dir(dirname, language=None, strict_sort=False):
61+
def check_kaldi_dir(dirname, language=None, strict_sort=False, tool_dir=None):
6262
strict_sort = "true" if strict_sort else "false"
63-
tool_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), "tools", "kaldi", "utils")
64-
63+
if not tool_dir:
64+
tool_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), "tools", "kaldi", "utils")
6565
if os.path.isfile(os.path.join(dirname, "text")):
6666
with open(os.path.join(dirname, "text")) as f:
6767
texts = dict(parse_line(line) for line in f)

ssak/utils/kaldi_converter.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,17 @@ def merge_data(self, dataset, new_data):
113113
diff_a_b = set(dict_dataset.keys()).difference(set(dict_new_data.keys()))
114114
diff_b_a = set(dict_new_data.keys()).difference(set(dict_dataset.keys()))
115115
logger.warning(f"The data you are trying to merge have different lengths at step {self.__class__.__name__} (execute_order={self.execute_order})!")
116-
logger.warning(f"Dataset {len(dataset)} has {len(diff_a_b)} not present in new data")
117-
logger.warning(f"New data {len(new_data)} has {len(diff_b_a)} not present in dataset")
118-
logger.warning("Writing ids to debug.txt")
119-
with open("debug.txt", "w") as f:
120-
if len(diff_a_b) > 0:
121-
f.write("In datset but not in new data:\n")
122-
for i in diff_a_b:
123-
f.write(f"{i}\n")
124-
if len(diff_b_a) > 0:
116+
logger.warning(f"Dataset ({len(dataset)} rows) has {len(diff_a_b)} rows not present in new data")
117+
logger.warning(f"New data ({len(new_data)} rows) has {len(diff_b_a)} rows not present in dataset")
118+
logger.warning("Writing ids to log2kaldi/missing_ids.txt")
119+
os.makedirs("kaldi_data_processing", exist_ok=True)
120+
if len(diff_a_b) > 0:
121+
with open(os.path.join("kaldi_data_processing",f"merge_new_data_missing_{self.execute_order}_{self.__class__.__name__}.txt"), "w") as f:
122+
f.write("In dataset but not in new data:\n")
123+
for i in diff_a_b:
124+
f.write(f"{i}\n")
125+
if len(diff_b_a) > 0:
126+
with open(os.path.join("kaldi_data_processing",f"merge_dataset_missing_{self.execute_order}_{self.__class__.__name__}.txt"), "w") as f:
125127
f.write("In new data but not in dataset:\n")
126128
for i in diff_b_a:
127129
f.write(f"{i}\n")

ssak/utils/kaldi_dataset.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
logger = logging.getLogger(__name__)
1212

13+
LOG_FOLDER = "kaldi_data_processing"
1314

1415
@dataclass
1516
class KaldiDatasetRow:
@@ -247,6 +248,30 @@ def get_duration(self, mode=sum, target="segment"):
247248
return mode(durations)
248249
return mode([i.duration for i in self.dataset])
249250

251+
def check_if_segments_in_audios(self, acceptance_end_s=0.25):
252+
from pydub.utils import mediainfo
253+
254+
new_data = []
255+
removed_lines = []
256+
files_duration = dict()
257+
for row in tqdm(self, desc="Check if segments are in audios"):
258+
if row.audio_path not in files_duration:
259+
dur = round(float(mediainfo(row.audio_path)["duration"]), 3)
260+
files_duration[row.audio_path] = dur
261+
dur = files_duration[row.audio_path]
262+
if row.start >= dur:
263+
removed_lines.append(row)
264+
elif row.end > dur + acceptance_end_s:
265+
removed_lines.append(row)
266+
else:
267+
new_data.append(row)
268+
self.dataset = new_data
269+
logger.info(f"Removed {len(removed_lines)} segments that were not in audios (start or end after audio), check removed_lines_not_in_audios file")
270+
os.makedirs(LOG_FOLDER, exist_ok=True)
271+
with open(os.path.join(LOG_FOLDER, "removed_lines_not_in_audios"), "w") as f:
272+
for row in removed_lines:
273+
f.write(str(row) + "\n")
274+
250275
def filter_by_audio_ids(self, audio_ids):
251276
"""
252277
Filter the dataset by audio ids
@@ -281,7 +306,7 @@ def filter_by_speakers(self, speakers):
281306
new_dataset.append(row)
282307
return new_dataset
283308

284-
def normalize_dataset(self, apply_text_normalization=True):
309+
def normalize_dataset(self, apply_text_normalization=True, wer_format=False):
285310
"""
286311
Normalize the texts in the dataset using the format_text_latin function from ssak.utils.text_latin
287312
@@ -296,7 +321,7 @@ def normalize_dataset(self, apply_text_normalization=True):
296321
for row in tqdm(self.dataset, total=len(self.dataset), desc="Normalizing texts"):
297322
from ssak.utils.text_latin import format_text_latin
298323

299-
row.normalized_text = format_text_latin(row.text)
324+
row.normalized_text = format_text_latin(row.text, wer_format=wer_format)
300325
if apply_text_normalization:
301326
row.text = row.normalized_text
302327

@@ -357,7 +382,8 @@ def normalize_audios(self, output_wavs_conversion_folder, target_sample_rate=160
357382
else:
358383
removed_lines.append(row)
359384
self.dataset = new_dataset
360-
with open("removed_lines", "w") as f:
385+
os.makedirs(LOG_FOLDER, exist_ok=True)
386+
with open(os.path.join(LOG_FOLDER, "removed_lines_audio_empty"), "w") as f:
361387
for row in removed_lines:
362388
f.write(str(row) + "\n")
363389

@@ -571,7 +597,8 @@ def apply_filter(self, filter, filter_out=True):
571597
else:
572598
removed_lines.append(row)
573599
self.dataset = new_data
574-
with open("filtered_out", "w") as f:
600+
os.makedirs(LOG_FOLDER, exist_ok=True)
601+
with open(os.path.join(LOG_FOLDER, f"filtered_out_with_{filter.__name__ }"), "w") as f:
575602
for row in removed_lines:
576603
f.write(str(row) + "\n")
577604

ssak/utils/monitoring.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ class Monitoring:
398398

399399
def __init__(self, output_folder="", name="", interval=0.25, device="cuda", plot_monitoring=True, show_steps_in_plots=True):
400400
self.device = device
401+
self.device_name = None
401402
self.output_folder = output_folder
402403
if not name:
403404
self.name = output_folder
@@ -408,13 +409,25 @@ def __init__(self, output_folder="", name="", interval=0.25, device="cuda", plot
408409
self.will_plot_monitoring = plot_monitoring
409410
if self.will_plot_monitoring:
410411
pass
412+
self.device = self.device if self.device else 0
413+
if self.device=="cuda" or self.device == "gpu":
414+
self.device = 0
415+
elif self.device.startswith("cuda:"):
416+
self.device = int(self.device.split(":")[1])
417+
if self.device != "cpu" and isinstance(self.device, int):
418+
num_gpus = get_num_gpus()
419+
if self.device>num_gpus:
420+
raise ValueError(f"GPU {self.device} doesn't exist, only {num_gpus} GPUs available")
421+
self.device = ALL_GPU_INDICES[self.device]
422+
elif self.device != "cpu":
423+
raise ValueError(f"Device {self.device} doesn't exist, use 'gpu', 'cpu', 'cuda', 'cuda:0' or '0' for example")
411424

412425
def _finish_step(self, monitoring, step_values, step=0, start=0):
413426
for i in step_values:
414427
if i not in monitoring:
415428
monitoring[i] = []
416429
monitoring[i].extend(step_values[i])
417-
if self.steps and len(self.steps)>0:
430+
if self.steps and len(self.steps) > 0 and step < len(self.steps):
418431
if "steps" not in monitoring:
419432
monitoring["steps"] = []
420433
if "steps_end" not in monitoring:
@@ -464,9 +477,11 @@ def _monitor(self):
464477
start = time.time() - monitoring["time_points"][-1]
465478
if "device" in monitoring and monitoring["device"] != (pynvml.nvmlDeviceGetName(handle) if handle else "cpu"):
466479
raise ValueError("The device used in the monitoring is different from the one specified in the current monitoring")
480+
self.device_name = monitoring.get("device", "cpu")
467481
else:
468482
monitoring = dict()
469483
monitoring["device"] = pynvml.nvmlDeviceGetName(handle) if handle else "cpu"
484+
self.device_name = monitoring["device"]
470485
start = time.time()
471486
step = 0
472487
step_monitoring = dict()
@@ -498,12 +513,6 @@ def start(self, steps=None):
498513
steps: list of str
499514
List of steps to monitor
500515
"""
501-
self.device = self.device if self.device else 0
502-
if self.device == "cuda" or self.device == "gpu":
503-
self.device = 0
504-
if self.device != "cpu":
505-
get_num_gpus()
506-
self.device = ALL_GPU_INDICES[self.device]
507516
self.event_stop = threading.Event()
508517
self.event_next = threading.Event()
509518
self.event_error = threading.Event()
@@ -530,6 +539,18 @@ def stop(self, error=False):
530539
self.event_stop.set()
531540
self.monitoring_thread.join()
532541

542+
def get_device_name(self):
543+
if self.device_name is None:
544+
if self.device != "cpu":
545+
pynvml.nvmlInit()
546+
handle = pynvml.nvmlDeviceGetHandleByIndex(self.device)
547+
else:
548+
handle = None
549+
self.device_name = pynvml.nvmlDeviceGetName(handle) if handle else "cpu"
550+
if handle:
551+
pynvml.nvmlShutdown()
552+
return self.device_name
553+
533554
def plot_hardware(self, values, times, output_folder, ylabel="RAM Usage", lims=None, steps=None):
534555
import matplotlib.pyplot as plt
535556

ssak/utils/text_latin.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,16 @@
33

44
from ssak.utils.text_basic import (
55
_punctuation,
6+
collapse_whitespace,
7+
format_special_characters,
8+
remove_punctuations,
9+
transliterate,
610
)
711
from ssak.utils.text_utils import (
812
numbers_and_symbols_to_letters,
913
regex_escape,
1014
remove_special_characters,
1115
)
12-
from ssak.utils.text_basic import (
13-
collapse_whitespace,
14-
format_special_characters,
15-
remove_punctuations,
16-
transliterate,
17-
)
1816

1917

2018
def _rm_key(d, key):
@@ -42,17 +40,7 @@ def find_acronyms(text, ignore_first_upper_words=True):
4240

4341

4442
def format_text_latin(
45-
text,
46-
lang="fr",
47-
lower_case=True,
48-
keep_punc=False,
49-
remove_ligatures=True,
50-
convert_numbers=True,
51-
extract_parenthesis=False,
52-
fid_acronyms=None,
53-
fid_special_chars=None,
54-
safety_checks=True,
55-
remove_suspicious_entry=False,
43+
text, lang="fr", lower_case=True, keep_punc=False, remove_ligatures=True, convert_numbers=True, extract_parenthesis=False, fid_acronyms=None, fid_special_chars=None, safety_checks=True, remove_suspicious_entry=False, wer_format=True
5644
):
5745
opts = _rm_key(locals(), "text")
5846

@@ -139,7 +127,10 @@ def format_text_latin(
139127
text = re.sub(":", " : ", text)
140128
text = re.sub(";", " ; ", text)
141129
# text = re.sub("^ *-+", "", text)
142-
text = re.sub("'", "' ", text)
130+
if wer_format:
131+
text = re.sub("'", "' ", text)
132+
else:
133+
text = re.sub("' ", "'", text)
143134
text = re.sub(r"\^+", "", text)
144135
text = re.sub(" +(- +)+", " ", text)
145136
text = re.sub("- ", " ", text)
@@ -171,7 +162,6 @@ def format_text_latin(
171162
# text_rep=split_h[0]+' heures '+split_h[1]
172163
text = text.replace(h, text_rep)
173164

174-
if convert_numbers:
175165
text = numbers_and_symbols_to_letters(text, lang=lang)
176166

177167
if lang == "fr":

ssak/utils/wer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def further_normalize(s):
155155

156156
if strong_normalization:
157157
from ssak.utils.text_basic import collapse_whitespace
158+
158159
def remove_not_words(s):
159160
# Remove any character that is not alpha-numeric (e.g. apostrophes, dashes, ...)
160161
return collapse_whitespace(re.sub(r"[^\w]", " ", s))
@@ -781,7 +782,9 @@ def func_ylabel(title, *args, **kwargs):
781782
plt.legend(
782783
fontsize=label_size,
783784
ncols=2,
784-
loc="best",
785+
# loc="best",
786+
loc='upper left',
787+
bbox_to_anchor=(1, 1)
785788
)
786789
if show_axisnames:
787790
use_percent = scale == 100

0 commit comments

Comments
 (0)