diff --git a/src/transformers/models/bart/tokenization_bart.py b/src/transformers/models/bart/tokenization_bart.py index f674afe1a412..91971a29191c 100644 --- a/src/transformers/models/bart/tokenization_bart.py +++ b/src/transformers/models/bart/tokenization_bart.py @@ -18,10 +18,9 @@ from functools import lru_cache from typing import Optional -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/bertweet/tokenization_bertweet.py b/src/transformers/models/bertweet/tokenization_bertweet.py index 3ce1a3182bf9..94e2a7240d1c 100644 --- a/src/transformers/models/bertweet/tokenization_bertweet.py +++ b/src/transformers/models/bertweet/tokenization_bertweet.py @@ -21,10 +21,9 @@ from shutil import copyfile from typing import Optional -import regex - from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex logger = logging.get_logger(__name__) diff --git a/src/transformers/models/blenderbot/tokenization_blenderbot.py b/src/transformers/models/blenderbot/tokenization_blenderbot.py index 76719fa25494..4cdcdde8b016 100644 --- a/src/transformers/models/blenderbot/tokenization_blenderbot.py +++ b/src/transformers/models/blenderbot/tokenization_blenderbot.py @@ -19,10 +19,9 @@ from functools import lru_cache from typing import Optional -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py index adb54025ce23..91fa8ffb1dda 100644 --- a/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/tokenization_blenderbot_small.py @@ -18,10 +18,9 @@ import os from typing import Optional -import regex as re - from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/clip/tokenization_clip.py b/src/transformers/models/clip/tokenization_clip.py index 625d26dc6960..7b1a52224423 100644 --- a/src/transformers/models/clip/tokenization_clip.py +++ b/src/transformers/models/clip/tokenization_clip.py @@ -20,10 +20,9 @@ from functools import lru_cache from typing import Optional -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/clvp/tokenization_clvp.py b/src/transformers/models/clvp/tokenization_clvp.py index 4b0b285561c5..03682e693cbd 100644 --- a/src/transformers/models/clvp/tokenization_clvp.py +++ b/src/transformers/models/clvp/tokenization_clvp.py @@ -19,10 +19,9 @@ from functools import lru_cache from typing import Optional -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re from .number_normalizer import EnglishNormalizer diff --git a/src/transformers/models/codegen/tokenization_codegen.py b/src/transformers/models/codegen/tokenization_codegen.py index 
4d08c6acd5bb..5906ae77e9f2 100644 --- a/src/transformers/models/codegen/tokenization_codegen.py +++ b/src/transformers/models/codegen/tokenization_codegen.py @@ -20,9 +20,9 @@ from typing import TYPE_CHECKING, Optional, Union import numpy as np -import regex as re from ...utils import logging, to_py_obj +from ...utils.safe import regex as re if TYPE_CHECKING: diff --git a/src/transformers/models/ctrl/tokenization_ctrl.py b/src/transformers/models/ctrl/tokenization_ctrl.py index 5b7935e6404d..0a2c0eff46c7 100644 --- a/src/transformers/models/ctrl/tokenization_ctrl.py +++ b/src/transformers/models/ctrl/tokenization_ctrl.py @@ -18,10 +18,9 @@ import os from typing import Optional -import regex as re - from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/deberta/tokenization_deberta.py b/src/transformers/models/deberta/tokenization_deberta.py index 74e958c8030b..858c35bd0870 100644 --- a/src/transformers/models/deberta/tokenization_deberta.py +++ b/src/transformers/models/deberta/tokenization_deberta.py @@ -18,10 +18,9 @@ import os from typing import Optional -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/deepseek_vl/convert_deepseek_vl_weights_to_hf.py b/src/transformers/models/deepseek_vl/convert_deepseek_vl_weights_to_hf.py index 3e9b6a37fe09..4c4f2a8c6df6 100644 --- a/src/transformers/models/deepseek_vl/convert_deepseek_vl_weights_to_hf.py +++ b/src/transformers/models/deepseek_vl/convert_deepseek_vl_weights_to_hf.py @@ -18,7 +18,6 @@ import os from typing import Optional -import regex as re import torch from accelerate import init_empty_weights from huggingface_hub import snapshot_download @@ -33,6 +32,7 @@ DeepseekVLProcessor, ) from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD +from transformers.utils.safe import regex as re # fmt: off diff --git a/src/transformers/models/deepseek_vl_hybrid/convert_deepseek_vl_hybrid_weights_to_hf.py b/src/transformers/models/deepseek_vl_hybrid/convert_deepseek_vl_hybrid_weights_to_hf.py index 9f377a53c8f3..108a66fae9d4 100644 --- a/src/transformers/models/deepseek_vl_hybrid/convert_deepseek_vl_hybrid_weights_to_hf.py +++ b/src/transformers/models/deepseek_vl_hybrid/convert_deepseek_vl_hybrid_weights_to_hf.py @@ -18,7 +18,6 @@ import os from typing import Optional -import regex as re import torch from accelerate import init_empty_weights from huggingface_hub import snapshot_download @@ -39,6 +38,7 @@ OPENAI_CLIP_STD, PILImageResampling, ) +from transformers.utils.safe import regex as re # fmt: off diff --git a/src/transformers/models/deprecated/jukebox/tokenization_jukebox.py b/src/transformers/models/deprecated/jukebox/tokenization_jukebox.py index b9948c5c354e..bfaba58a4bae 100644 --- a/src/transformers/models/deprecated/jukebox/tokenization_jukebox.py +++ b/src/transformers/models/deprecated/jukebox/tokenization_jukebox.py @@ -22,12 +22,12 @@ from typing import Any, Optional, Union import numpy as np -import regex from ....tokenization_utils import AddedToken, PreTrainedTokenizer from ....tokenization_utils_base import BatchEncoding from ....utils import TensorType, is_torch_available, logging from ....utils.generic import is_numpy_array +from ....utils.safe import regex logger = 
logging.get_logger(__name__) diff --git a/src/transformers/models/deprecated/tapex/tokenization_tapex.py b/src/transformers/models/deprecated/tapex/tokenization_tapex.py index fa74d8aa3b55..53f43a36be18 100644 --- a/src/transformers/models/deprecated/tapex/tokenization_tapex.py +++ b/src/transformers/models/deprecated/tapex/tokenization_tapex.py @@ -20,12 +20,11 @@ from functools import lru_cache from typing import Optional, Union -import regex as re - from ....file_utils import ExplicitEnum, PaddingStrategy, TensorType, add_end_docstrings, is_pandas_available from ....tokenization_utils import AddedToken, PreTrainedTokenizer from ....tokenization_utils_base import ENCODE_KWARGS_DOCSTRING, BatchEncoding, TextInput, TruncationStrategy from ....utils import logging +from ....utils.safe import regex as re if is_pandas_available(): diff --git a/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py b/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py index 655bbdc0230f..8ec66edb804b 100644 --- a/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py +++ b/src/transformers/models/depth_pro/convert_depth_pro_weights_to_hf.py @@ -17,7 +17,6 @@ import os from typing import Optional -import regex as re import torch from huggingface_hub import hf_hub_download @@ -26,6 +25,7 @@ DepthProForDepthEstimation, DepthProImageProcessorFast, ) +from transformers.utils.safe import regex as re # fmt: off diff --git a/src/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py index 004a1c36f59c..9255227f997c 100644 --- a/src/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py @@ -18,10 +18,9 @@ import os from typing import Optional -import regex - from ...tokenization_utils import PreTrainedTokenizer from ...utils import logging, requires_backends +from ...utils.safe import regex logger = logging.get_logger(__name__) diff --git a/src/transformers/models/got_ocr2/convert_got_ocr2_weights_to_hf.py b/src/transformers/models/got_ocr2/convert_got_ocr2_weights_to_hf.py index 39496fe043ed..a332031ea6e0 100644 --- a/src/transformers/models/got_ocr2/convert_got_ocr2_weights_to_hf.py +++ b/src/transformers/models/got_ocr2/convert_got_ocr2_weights_to_hf.py @@ -18,7 +18,6 @@ import os from typing import Optional -import regex as re import torch from huggingface_hub import snapshot_download from safetensors import safe_open @@ -33,6 +32,7 @@ ) from transformers.convert_slow_tokenizer import TikTokenConverter from transformers.tokenization_utils import AddedToken +from transformers.utils.safe import regex as re if is_vision_available(): diff --git a/src/transformers/models/gpt2/tokenization_gpt2.py b/src/transformers/models/gpt2/tokenization_gpt2.py index 608164ef2d83..8387f71d4e4b 100644 --- a/src/transformers/models/gpt2/tokenization_gpt2.py +++ b/src/transformers/models/gpt2/tokenization_gpt2.py @@ -19,10 +19,9 @@ from functools import lru_cache from typing import Optional -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py index 736a95247dfb..63d1ccc32500 100644 --- 
a/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py +++ b/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py @@ -19,7 +19,6 @@ from pathlib import Path from typing import Optional -import regex as re import tiktoken import torch from safetensors.torch import load_file as safe_load @@ -31,6 +30,7 @@ PreTrainedTokenizerFast, ) from transformers.convert_slow_tokenizer import TikTokenConverter +from transformers.utils.safe import regex as re # fmt: off diff --git a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py index 7877c5b4668d..f847757e3952 100644 --- a/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/tokenization_layoutlmv3.py @@ -19,8 +19,6 @@ from functools import lru_cache from typing import Optional, Union -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...tokenization_utils_base import ( BatchEncoding, @@ -31,6 +29,7 @@ TruncationStrategy, ) from ...utils import PaddingStrategy, TensorType, add_end_docstrings, logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/led/tokenization_led.py b/src/transformers/models/led/tokenization_led.py index d110ac30d969..8982a526c2bc 100644 --- a/src/transformers/models/led/tokenization_led.py +++ b/src/transformers/models/led/tokenization_led.py @@ -19,11 +19,10 @@ from functools import lru_cache from typing import Optional, Union -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...tokenization_utils_base import BatchEncoding, EncodedInput from ...utils import PaddingStrategy, logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/longformer/tokenization_longformer.py b/src/transformers/models/longformer/tokenization_longformer.py index 104bdd7a9b99..63b81e2ad575 100644 --- a/src/transformers/models/longformer/tokenization_longformer.py +++ b/src/transformers/models/longformer/tokenization_longformer.py @@ -18,10 +18,9 @@ from functools import lru_cache from typing import Optional -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py index 4bb19bb5ee73..fcfea6fb1c7e 100644 --- a/src/transformers/models/luke/tokenization_luke.py +++ b/src/transformers/models/luke/tokenization_luke.py @@ -22,7 +22,6 @@ from typing import Optional, Union import numpy as np -import regex as re from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils_base import ( @@ -38,6 +37,7 @@ to_py_obj, ) from ...utils import add_end_docstrings, is_torch_tensor, logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/markuplm/tokenization_markuplm.py b/src/transformers/models/markuplm/tokenization_markuplm.py index 0a6f7c3bd6a0..bec10fc86660 100644 --- a/src/transformers/models/markuplm/tokenization_markuplm.py +++ b/src/transformers/models/markuplm/tokenization_markuplm.py @@ -19,8 +19,6 @@ from functools import lru_cache from typing import Optional, Union -import regex as re - from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings from 
...tokenization_utils import AddedToken, PreTrainedTokenizer from ...tokenization_utils_base import ( @@ -33,6 +31,7 @@ TruncationStrategy, ) from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/mllama/convert_mllama_weights_to_hf.py b/src/transformers/models/mllama/convert_mllama_weights_to_hf.py index c773d0514f81..72c73b878446 100644 --- a/src/transformers/models/mllama/convert_mllama_weights_to_hf.py +++ b/src/transformers/models/mllama/convert_mllama_weights_to_hf.py @@ -19,7 +19,6 @@ import os from typing import Optional -import regex as re import torch import torch.nn.functional as F @@ -33,6 +32,7 @@ from transformers.convert_slow_tokenizer import TikTokenConverter from transformers.models.mllama.configuration_mllama import MllamaTextConfig, MllamaVisionConfig from transformers.models.mllama.image_processing_mllama import get_all_supported_aspect_ratios +from transformers.utils.safe import regex as re # fmt: off diff --git a/src/transformers/models/mvp/tokenization_mvp.py b/src/transformers/models/mvp/tokenization_mvp.py index f6039df2dc02..4b6397b9f44c 100644 --- a/src/transformers/models/mvp/tokenization_mvp.py +++ b/src/transformers/models/mvp/tokenization_mvp.py @@ -18,10 +18,9 @@ from functools import lru_cache from typing import Optional -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py index 373aa6cb6e45..d8cc4c477f01 100644 --- a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py +++ b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py @@ -15,7 +15,6 @@ import json import os -import regex as re import torch from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from safetensors.torch import load_file as safe_load_file @@ -28,6 +27,7 @@ PixtralProcessor, PixtralVisionConfig, ) +from transformers.utils.safe import regex as re """ diff --git a/src/transformers/models/qwen2/tokenization_qwen2.py b/src/transformers/models/qwen2/tokenization_qwen2.py index be121adb5442..1d935f856d51 100644 --- a/src/transformers/models/qwen2/tokenization_qwen2.py +++ b/src/transformers/models/qwen2/tokenization_qwen2.py @@ -20,10 +20,9 @@ from functools import lru_cache from typing import Optional -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py b/src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py index 4871311e114b..e056c6c740a0 100644 --- a/src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py +++ b/src/transformers/models/recurrent_gemma/convert_recurrent_gemma_to_hf.py @@ -19,6 +19,7 @@ from accelerate import init_empty_weights from transformers import GemmaTokenizer, RecurrentGemmaConfig, RecurrentGemmaForCausalLM +from transformers.utils.safe import regex as re try: @@ -30,8 +31,6 @@ ) GemmaTokenizerFast = None -import regex as re - """ Sample usage: @@ -63,6 +62,7 @@ num_hidden_layers=26, ) + gemma_7b_config = RecurrentGemmaConfig() CONFIG_MAPPING = {"2B": gemma_2b_config, "7B": gemma_7b_config} diff --git 
a/src/transformers/models/roberta/tokenization_roberta.py b/src/transformers/models/roberta/tokenization_roberta.py index 67cdcbbf488a..c356fc28b0f7 100644 --- a/src/transformers/models/roberta/tokenization_roberta.py +++ b/src/transformers/models/roberta/tokenization_roberta.py @@ -19,10 +19,9 @@ from functools import lru_cache from typing import Optional -import regex as re - from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re logger = logging.get_logger(__name__) diff --git a/src/transformers/models/whisper/english_normalizer.py b/src/transformers/models/whisper/english_normalizer.py index 265ea04b5334..0f9594db7342 100644 --- a/src/transformers/models/whisper/english_normalizer.py +++ b/src/transformers/models/whisper/english_normalizer.py @@ -20,7 +20,7 @@ from re import Match from typing import Optional, Union -import regex +from ...utils.safe import regex # non-ASCII letters that are not separated by "NFKD" normalization diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 34d9a8965be8..3984aff53096 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -21,10 +21,10 @@ from typing import Optional, Union import numpy as np -import regex as re from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging +from ...utils.safe import regex as re from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 135f20bf4cf9..5f7e4d10f828 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -19,8 +19,6 @@ from pathlib import Path from typing import Optional, Union, get_args -import regex as re - from .doc import ( MODELS_TO_PIPELINE, PIPELINE_TASKS_TO_SAMPLE_DOCSTRINGS, @@ -28,6 +26,7 @@ _prepare_output_docstrings, ) from .generic import ModelOutput +from .safe import regex as re PATH_TO_TRANSFORMERS = Path("src").resolve() / "transformers" diff --git a/src/transformers/utils/safe.py b/src/transformers/utils/safe.py new file mode 100644 index 000000000000..661e494f5581 --- /dev/null +++ b/src/transformers/utils/safe.py @@ -0,0 +1,137 @@ +# Copyright 2025 ModelCloud.ai team and The HuggingFace Inc. team. +## +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Thread-safe utilities and module wrappers usable across Transformers.""" + +from __future__ import annotations + +import threading +from functools import wraps +from types import ModuleType + +import regex as _regex + + +__all__ = ["ThreadSafe", "SafeRegex", "regex"] + + +class ThreadSafe(ModuleType): + """Generic proxy that exposes a module through a shared lock.""" + + def __init__(self, module: ModuleType): + super().__init__(module.__name__) + # `_hf_safe_` prefix is used to avoid colliding with the wrapped object namespace. + self._hf_safe_module = module + # Callable execution lock (re-entrant so wrapped code can re-enter safely) + self._hf_safe_lock = threading.RLock() + # Cache dict lock + self._hf_safe_callable_cache_lock = threading.Lock() + self._hf_safe_callable_cache: dict[str, object] = {} + # Retain module metadata so introspection tools relying on attributes + # like __doc__, __spec__, etc, can see the original values. + metadata = {"__doc__": module.__doc__} + for attr in ("__package__", "__file__", "__spec__"): + if hasattr(module, attr): + metadata[attr] = getattr(module, attr) + self.__dict__.update(metadata) + + def __getattr__(self, name: str): + attr = getattr(self._hf_safe_module, name) + if callable(attr): + with self._hf_safe_callable_cache_lock: + cached = self._hf_safe_callable_cache.get(name) + if cached is not None and getattr(cached, "__wrapped__", None) is attr: + return cached + + @wraps(attr) + def _hf_safe_locked(*args, **kwargs): + with self._hf_safe_lock: + return attr(*args, **kwargs) + + _hf_safe_locked.__wrapped__ = attr + self._hf_safe_callable_cache[name] = _hf_safe_locked + return _hf_safe_locked + return attr + + def __dir__(self): + return sorted(set(super().__dir__()) | set(dir(self._hf_safe_module))) + + +class _ThreadSafeProxy: + """Lightweight proxy that serializes access to an object with a shared lock.""" + + def __init__(self, value, lock): + # `_hf_safe_` prefix is used to avoid colliding with the wrapped object namespace. + object.__setattr__(self, "_hf_safe_value", value) + object.__setattr__(self, "_hf_safe_lock", lock) + object.__setattr__(self, "_hf_safe_cache_lock", threading.Lock()) + object.__setattr__(self, "_hf_safe_callable_cache", {}) + object.__setattr__(self, "__wrapped__", value) + + def __getattr__(self, name: str): + attr = getattr(self._hf_safe_value, name) + if callable(attr): + with self._hf_safe_cache_lock: + cached = self._hf_safe_callable_cache.get(name) + if cached is not None and getattr(cached, "__wrapped__", None) is attr: + return cached + + @wraps(attr) + def _hf_safe_locked(*args, **kwargs): + with self._hf_safe_lock: + return attr(*args, **kwargs) + + _hf_safe_locked.__wrapped__ = attr + self._hf_safe_callable_cache[name] = _hf_safe_locked + return _hf_safe_locked + return attr + + def __setattr__(self, name, value): + with self._hf_safe_lock: + setattr(self._hf_safe_value, name, value) + + def __delattr__(self, name): + with self._hf_safe_lock: + delattr(self._hf_safe_value, name) + + def __dir__(self): + with self._hf_safe_lock: + return dir(self._hf_safe_value) + + def __repr__(self): + with self._hf_safe_lock: + return repr(self._hf_safe_value) + + def __call__(self, *args, **kwargs): + with self._hf_safe_lock: + return self._hf_safe_value(*args, **kwargs) + + +class SafeRegex(ThreadSafe): + """Proxy module that exposes ``regex`` through a shared lock.""" + + # We must proxy the shared regex lock to any objects returned here since + # compiled patterns expose methods (e.g. 
pattern.match) that must also be + serialized. + + def compile(self, *args, **kwargs): + pattern = self._hf_safe_module.compile(*args, **kwargs) + return _ThreadSafeProxy(pattern, self._hf_safe_lock) + + def Regex(self, *args, **kwargs): + pattern = self._hf_safe_module.Regex(*args, **kwargs) + return _ThreadSafeProxy(pattern, self._hf_safe_lock) + + +regex = SafeRegex(_regex) diff --git a/tests/utils/test_safe.py b/tests/utils/test_safe.py new file mode 100644 index 000000000000..a4e83b7fa63b --- /dev/null +++ b/tests/utils/test_safe.py @@ -0,0 +1,180 @@ +# Copyright 2025 ModelCloud.ai team and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import sys +import threading +import types + +import pytest + +from transformers.utils.safe import ThreadSafe, regex + + +pytestmark = pytest.mark.skipif( + not hasattr(sys, "_is_gil_enabled") or sys._is_gil_enabled(), + reason="Safe regex test only runs when the GIL is disabled", +) + + +def _exercise_pattern(pattern_factory): + expected = {"group0": "test123", "prefix": "test", "number": "123"} + + num_threads = 32 + errors = [] + errors_lock = threading.Lock() + ready_group = threading.Barrier(num_threads + 1) + run_event = threading.Event() + + def worker(): + try: + ready_group.wait() + run_event.wait() + for _ in range(100): + pattern = pattern_factory() + match = pattern.match("test123") + assert match is not None + assert match.group(0) == expected["group0"] + assert match.group("prefix") == expected["prefix"] + assert match.group("number") == expected["number"] + except Exception as exc: + with errors_lock: + errors.append(exc) + + threads = [threading.Thread(target=worker) for _ in range(num_threads)] + for thread in threads: + thread.start() + + ready_group.wait() # Ensure every thread is ready before triggering execution + run_event.set() + + for thread in threads: + thread.join() + + return errors + + +def test_regex_thread_safety_under_gil0(): + def factory(): + return regex.compile(r"(?P<prefix>test)(?P<number>\d+)") + + errors = _exercise_pattern(factory) + assert not errors + + +def test_regex_thread_safety_shared_pattern_under_gil0(): + shared_pattern = regex.compile(r"(?P<prefix>test)(?P<number>\d+)") + + def factory(): + return shared_pattern + + errors = _exercise_pattern(factory) + assert not errors + + +def test_regex_thread_safety_direct_match_under_gil0(): + def factory(): + class _DirectMatcher: + def match(self, text): + return regex.match(r"(?P<prefix>test)(?P<number>\d+)", text) + + return _DirectMatcher() + + errors = _exercise_pattern(factory) + assert not errors + + +def test_regex_threadsafe_allows_reentrant_calls_under_gil0(): + pattern = regex.compile(r"(?P<prefix>test)(?P<number>\d+)") + completed = threading.Event() + + def target(): + def replace(match): + inner = regex.match(r"(?P<prefix>test)(?P<number>\d+)", match.group(0)) + return inner.group(0) + + result = pattern.sub(replace, "test123") + assert result == "test123" + completed.set() + + worker = threading.Thread(target=target) + worker.start() + worker.join(timeout=2) + + assert 
completed.is_set(), "Re-entrant call should not deadlock" + + +def test_threadsafe_callable_cache_is_serialized_under_gil0(): + module = types.ModuleType("_threadsafe_test_module") + + counter_lock = threading.Lock() + call_counter = {"count": 0} + + def increment(value: int): + with counter_lock: + call_counter["count"] += 1 + return value + 1 + + module.increment = increment + thread_safe_module = ThreadSafe(module) + + num_threads = 32 + errors = [] + errors_lock = threading.Lock() + ready_group = threading.Barrier(num_threads + 1) + run_event = threading.Event() + + def worker(): + try: + ready_group.wait() + run_event.wait() + for _ in range(256): + func = thread_safe_module.increment + assert func(1) == 2 + except Exception as exc: + with errors_lock: + errors.append(exc) + + threads = [threading.Thread(target=worker) for _ in range(num_threads)] + for thread in threads: + thread.start() + + ready_group.wait() # Synchronize thread start to mimic high contention scenarios + run_event.set() + + for thread in threads: + thread.join() + + assert not errors + # Each invocation should be recorded; the exact count confirms every worker executed. + assert call_counter["count"] == num_threads * 256 + + +def test_threadsafe_copies_existing_module_metadata_under_gil0(): + thread_safe_regex = ThreadSafe(regex) + + for attr in ("__package__", "__file__", "__spec__"): + if hasattr(regex, attr): + assert hasattr(thread_safe_regex, attr) + assert getattr(thread_safe_regex, attr) == getattr(regex, attr) + + +def test_threadsafe_skips_missing_module_metadata_under_gil0(): + thread_safe_builtins = ThreadSafe(builtins) + + for attr in ("__package__", "__file__", "__spec__"): + if hasattr(builtins, attr): + assert getattr(thread_safe_builtins, attr) == getattr(builtins, attr) + else: + assert not hasattr(thread_safe_builtins, attr) diff --git a/tests/utils/test_safe_crash.py b/tests/utils/test_safe_crash.py new file mode 100644 index 000000000000..3a3f3509a24c --- /dev/null +++ b/tests/utils/test_safe_crash.py @@ -0,0 +1,92 @@ +# Copyright 2025 ModelCloud.ai team and The HuggingFace Inc. team. +## +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Regression harness that reproduces the raw `regex` crash when the GIL is off.""" + +import os +import subprocess +import sys + +import pytest + + +pytestmark = pytest.mark.skipif( + not hasattr(sys, "_is_gil_enabled") or sys._is_gil_enabled(), + reason="Crash regression only runs when the GIL is disabled", +) + + +@pytest.mark.xfail(strict=False, reason="Raw regex crashes under PYTHON_GIL=0") +def test_raw_regex_thread_safety_crashes_under_gil0(): + script = r""" +import threading +import regex + +pattern_text = r"(?Ptest)(?P\d+)" +ready_group = threading.Barrier(33) +run_event = threading.Event() + + +def worker(): + ready_group.wait() + run_event.wait() + for _ in range(100): + match = regex.match(pattern_text, "test123") + assert match is not None + assert match.group("prefix") == "test" + assert match.group("number") == "123" + + +threads = [threading.Thread(target=worker) for _ in range(32)] +for thread in threads: + thread.start() + +ready_group.wait() +run_event.set() + +for thread in threads: + thread.join() +""" + + env = dict(os.environ, PYTHON_GIL="0") + result = subprocess.run( + [sys.executable, "-c", script], + env=env, + capture_output=True, + text=True, + ) + + stdout = result.stdout.strip() + stderr = result.stderr.strip() + message_parts = [ + "raw regex subprocess run", + f"return code: {result.returncode}", + ] + if stdout: + message_parts.append(f"stdout:\n{stdout}") + if stderr: + message_parts.append(f"stderr:\n{stderr}") + message = "\n".join(message_parts) + + if result.returncode == 0: + pytest.fail("raw regex unexpectedly behaved thread-safely\n" + message) + + if result.returncode == -11: + message += "\nProcess terminated with SIGSEGV (Segmentation fault)." + + if message: + sys.stderr.write(message + "\n") + sys.stderr.flush() + + pytest.fail(message)