diff --git a/libs/infinity_emb/infinity_emb/transformer/acceleration.py b/libs/infinity_emb/infinity_emb/transformer/acceleration.py
index 1d7b7c7f..da9d0176 100644
--- a/libs/infinity_emb/infinity_emb/transformer/acceleration.py
+++ b/libs/infinity_emb/infinity_emb/transformer/acceleration.py
@@ -7,11 +7,18 @@
 from infinity_emb._optional_imports import CHECK_OPTIMUM, CHECK_TORCH, CHECK_TRANSFORMERS
 from infinity_emb.primitives import Device
 
+bettertransformer_available = False
 if CHECK_OPTIMUM.is_available:
-    from optimum.bettertransformer import (  # type: ignore[import-untyped]
-        BetterTransformer,
-        BetterTransformerManager,
-    )
+    from importlib.metadata import version
+    transformers_version_string = version("transformers")
+    transformers_version = tuple(int(number) for number in transformers_version_string.split("."))
+    if transformers_version < (4, 49, 0):
+        from optimum.bettertransformer import (  # type: ignore[import-untyped]
+            BetterTransformer,
+            BetterTransformerManager,
+        )
+        bettertransformer_available = True
+
 if CHECK_TORCH.is_available:
     import torch
@@ -36,6 +43,9 @@ def check_if_bettertransformer_possible(engine_args: "EngineArgs") -> bool:
     """verifies if attempting conversion to bettertransformers should be checked."""
     if not engine_args.bettertransformer:
         return False
+
+    if not bettertransformer_available:
+        return False
 
     config = AutoConfig.from_pretrained(
         pretrained_model_name_or_path=engine_args.model_name_or_path,
@@ -50,6 +60,12 @@ def to_bettertransformer(model: "PreTrainedModel", engine_args: "EngineArgs", lo
     if not engine_args.bettertransformer:
         return model
 
+    if not bettertransformer_available:
+        logger.warning(
+            "BetterTransformer is not available for transformers >= 4.49.0"
+        )
+        return model
+
     if engine_args.device == Device.mps or (
         hasattr(model, "device") and model.device.type == "mps"
    ):
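
For context on the gate above: per the diff's warning, optimum's `BetterTransformer` path no longer works with `transformers` 4.49.0 or newer, so the import is now guarded by a runtime version check. Below is a minimal standalone sketch of that check. The helper `parse_version` is hypothetical (not part of infinity_emb); it also illustrates one caveat of the diff's bare `int(...)` split, which would raise `ValueError` on a pre-release string such as `4.49.0.dev0`.

```python
# Standalone sketch of the version gate introduced in the diff.
# parse_version is a hypothetical helper, not infinity_emb API.
from importlib.metadata import version


def parse_version(package: str) -> tuple[int, ...]:
    """Return the leading numeric components of an installed package version.

    Calling int() on every dot-separated piece (as the diff does) raises
    ValueError on pre-release strings such as '4.49.0.dev0'; stopping at
    the first non-numeric piece keeps the comparison safe.
    """
    parts: list[int] = []
    for piece in version(package).split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


if __name__ == "__main__":
    gate_open = parse_version("transformers") < (4, 49, 0)
    print(f"BetterTransformer import allowed: {gate_open}")
```

Comparing integer tuples rather than raw strings matters here: `(4, 9, 0) < (4, 49, 0)` is correctly `True`, while the lexicographic comparison `"4.9.0" < "4.49.0"` is `False`.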