Skip to content

Commit ee6f0b5

Browse files
committed
fix multiprocessing in torch dataProvider
1 parent 23ecf0c commit ee6f0b5

File tree

6 files changed

+19
-7
lines changed

6 files changed

+19
-7
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## [1.1.4] - 2022-09-29
2+
### Changed
3+
- Improved `mltu.torch.dataProvider.DataProvider` to handle `multiprocessing`: it now falls back to `multithreading` when `multiprocessing` fails to start
4+
15
## [1.1.3] - 2022-09-29
26
### Changed
37
- Removed `Librosa` library dependency in requirements, now it is optional and required only with modules that use librosa
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
torch==1.13.1+cu117
1+
torch>=1.13.1+cu117
22
transformers==4.33.1
3-
onnx
3+
mltu==1.1.4
4+
onnx
5+
onnxruntime

Tutorials/10_wav2vec2_torch/test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,6 @@ def predict(self, audio: np.ndarray):
3939

4040
accum_cer.append(cer)
4141
accum_wer.append(wer)
42+
print(label)
4243

4344
pbar.set_description(f"Average CER: {np.average(accum_cer):.4f}, Average WER: {np.average(accum_wer):.4f}")

Tutorials/10_wav2vec2_torch/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,15 @@ def download_and_unzip(url, extract_to="Datasets", chunk_size=1024*1024):
6565
],
6666
transformers=[
6767
LabelIndexer(vocab),
68-
LabelPadding(max_word_length=configs.max_label_length, padding_value=len(vocab)),
6968
],
7069
use_cache=False,
7170
batch_postprocessors=[
72-
AudioPadding(max_audio_length=configs.max_audio_length, padding_value=0, use_on_batch=True)
71+
AudioPadding(max_audio_length=configs.max_audio_length, padding_value=0, use_on_batch=True),
72+
LabelPadding(padding_value=len(vocab), use_on_batch=True),
7373
],
7474
use_multiprocessing=True,
7575
max_queue_size=10,
76-
workers=64,
76+
workers=configs.train_workers,
7777
)
7878
train_dataProvider, test_dataProvider = data_provider.split(split=0.9)
7979

mltu/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.1.3"
1+
__version__ = "1.1.4"
22

33
from .annotations.images import Image
44
from .annotations.images import CVImage

mltu/torch/dataProvider.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,12 @@ def start_executor(self) -> None:
181181

182182
if not hasattr(self, "_executor"):
183183
if self.use_multiprocessing:
184-
self._executor = ProcessExecutor(self.process_data, self.workers)
184+
try:
185+
self._executor = ProcessExecutor(self.process_data, self.workers)
186+
except:
187+
self.use_multiprocessing = False
188+
self.logger.error("Failed to start multiprocessing, switching to multithreading")
189+
self._executor = ThreadExecutor(self.process_data, self.workers)
185190
else:
186191
self._executor = ThreadExecutor(self.process_data, self.workers)
187192

0 commit comments

Comments
 (0)