Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions libs/infinity_emb/infinity_emb/transformer/acceleration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@
from infinity_emb._optional_imports import CHECK_OPTIMUM, CHECK_TORCH, CHECK_TRANSFORMERS
from infinity_emb.primitives import Device

bettertransformer_available = False
if CHECK_OPTIMUM.is_available:
    from importlib.metadata import version

    def _parse_version_tuple(version_string: str) -> tuple:
        """Return the leading numeric components of *version_string* as a tuple.

        Tolerant of pre-release and local-version tags such as
        ``"4.50.0.dev0"`` or ``"4.49.0+cu121"``: parsing stops at the first
        non-numeric component instead of raising ``ValueError`` the way a
        plain ``int()`` cast over every dot-separated piece would.
        """
        parts = []
        for piece in version_string.split("."):
            # Strip a local-version suffix ("0+cu121" -> "0") before testing.
            piece = piece.split("+", 1)[0]
            if not piece.isdigit():
                # First non-numeric component (e.g. "dev0") ends the tuple.
                break
            parts.append(int(piece))
        return tuple(parts)

    # optimum's BetterTransformer integration does not support
    # transformers >= 4.49, so only import it for older versions.
    if _parse_version_tuple(version("transformers")) < (4, 49, 0):
        from optimum.bettertransformer import (  # type: ignore[import-untyped]
            BetterTransformer,
            BetterTransformerManager,
        )

        bettertransformer_available = True


if CHECK_TORCH.is_available:
import torch
Expand All @@ -36,6 +43,9 @@ def check_if_bettertransformer_possible(engine_args: "EngineArgs") -> bool:
"""verifies if attempting conversion to bettertransformers should be checked."""
if not engine_args.bettertransformer:
return False

if not bettertransformer_available:
return False

config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=engine_args.model_name_or_path,
Expand All @@ -50,6 +60,12 @@ def to_bettertransformer(model: "PreTrainedModel", engine_args: "EngineArgs", lo
if not engine_args.bettertransformer:
return model

if not bettertransformer_available:
logger.warning(
"BetterTransformer is not available for transformers package > 4.49"
)
return model

if engine_args.device == Device.mps or (
hasattr(model, "device") and model.device.type == "mps"
):
Expand Down