diff --git a/src/huggingface_hub/utils/tqdm.py b/src/huggingface_hub/utils/tqdm.py index 4c1fcef4be..586353de31 100644 --- a/src/huggingface_hub/utils/tqdm.py +++ b/src/huggingface_hub/utils/tqdm.py @@ -83,6 +83,7 @@ import io import logging import os +import threading import warnings from contextlib import contextmanager, nullcontext from pathlib import Path @@ -211,19 +212,48 @@ def is_tqdm_disabled(log_level: int) -> Optional[bool]: return None -class tqdm(old_tqdm): +class SafeDelLockMeta(type): + """ + Class for fixing `del tqdm_class._lock`: https://github.com/huggingface/datasets/issues/7660 + """ + + def __delattr__(cls, name): + if name == "_lock": + try: + super().__delattr__(name) + except AttributeError: + pass + else: + super().__delattr__(name) + + +class tqdm(old_tqdm, metaclass=SafeDelLockMeta): + # class tqdm(old_tqdm): """ Class to override `disable` argument in case progress bars are globally disabled. Taken from https://github.com/tqdm/tqdm/issues/619#issuecomment-619639324. """ + _lock = threading.RLock() # fallback, just in case + + @classmethod + def get_lock(cls): + if not hasattr(cls, "_lock") or cls._lock is None: + cls._lock = threading.RLock() + return cls._lock + def __init__(self, *args, **kwargs): name = kwargs.pop("name", None) # do not pass `name` to `tqdm` if are_progress_bars_disabled(name): kwargs["disable"] = True super().__init__(*args, **kwargs) + def update(self, n=1): + # Always get a valid lock + with self.get_lock(): + super().update(n) + def __delattr__(self, attr: str) -> None: """Fix for https://github.com/huggingface/huggingface_hub/issues/1603""" try: