diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 765918d6d..1a057ed86 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -18,6 +18,7 @@ on: - "Cargo.lock" - "rust-toolchain.toml" - "Dockerfile" + - "Dockerfile-neuron" branches: - "main" diff --git a/.github/workflows/integration-test.yaml b/.github/workflows/integration-test-habana.yaml similarity index 90% rename from .github/workflows/integration-test.yaml rename to .github/workflows/integration-test-habana.yaml index 10d894d2a..ccc1b8717 100644 --- a/.github/workflows/integration-test.yaml +++ b/.github/workflows/integration-test-habana.yaml @@ -1,4 +1,4 @@ -name: Run integration tests +name: Run Habana integration tests on: workflow_dispatch: @@ -28,4 +28,4 @@ jobs: working-directory: integration_tests run: | uv sync --locked --all-extras --dev - uv run pytest --durations=0 -sv . + uv run pytest --durations=0 -sv gaudi/ diff --git a/.github/workflows/integration-test-neuron.yaml b/.github/workflows/integration-test-neuron.yaml new file mode 100644 index 000000000..8be3630e2 --- /dev/null +++ b/.github/workflows/integration-test-neuron.yaml @@ -0,0 +1,33 @@ +name: Run Neuron integration tests + +on: + workflow_dispatch: + schedule: + - cron: '0 0 * * *' # Run the workflow nightly to check Neuron integration is working + +jobs: + tests: + concurrency: + group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + runs-on: + group: aws-inf2-8xlarge + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v5 + + - name: Build Docker image for Neuron + run: | + docker build . -f Dockerfile-neuron -t tei-neuron + + - name: Run integration tests + working-directory: integration_tests + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + DOCKER_IMAGE: tei-neuron + run: | + uv sync --locked --all-extras --dev + uv run pytest --durations=0 -sv neuron/ diff --git a/.github/workflows/matrix.json b/.github/workflows/matrix.json index 9449b53be..f9912b5d9 100644 --- a/.github/workflows/matrix.json +++ b/.github/workflows/matrix.json @@ -105,5 +105,13 @@ "extraBuildArgs": "PLATFORM=hpu", "grpc": true, "dockerfile": "Dockerfile-intel" + }, + { + "name": "neuron", + "imageNamePrefix": "neuron-", + "runOn": "always", + "sccache": true, + "grpc": true, + "dockerfile": "Dockerfile-neuron" } ] diff --git a/Dockerfile-neuron b/Dockerfile-neuron new file mode 100644 index 000000000..112c742b7 --- /dev/null +++ b/Dockerfile-neuron @@ -0,0 +1,189 @@ +FROM lukemathwalker/cargo-chef:latest-rust-1.92-bookworm AS chef +WORKDIR /usr/src + +ENV SCCACHE=0.10.0 +ENV RUSTC_WRAPPER=/usr/local/bin/sccache + +# Download, configure sccache +RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ + chmod +x /usr/local/bin/sccache + +FROM chef AS planner + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +# sccache specific variables +ARG SCCACHE_GHA_ENABLED + +COPY --from=planner /usr/src/recipe.json recipe.json + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo chef cook --release --features python-neuron --no-default-features --recipe-path recipe.json && sccache -s + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +FROM builder AS http-builder + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo build --release --bin text-embeddings-router -F python-neuron -F http --no-default-features && sccache -s + +FROM builder AS grpc-builder + +COPY proto proto + +RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ + --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ + cargo build --release --bin text-embeddings-router -F grpc -F python-neuron --no-default-features && sccache -s + +FROM public.ecr.aws/docker/library/ubuntu:22.04 AS neuron + +ENV HUGGINGFACE_HUB_CACHE=/data \ + PORT=80 + +ENV PATH="/usr/local/bin:/root/.local/bin:${PATH}" + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-dev \ + build-essential \ + git \ + curl \ + cmake \ + pkg-config \ + protobuf-compiler \ + ninja-build \ + && rm -rf /var/lib/apt/lists/* + +RUN ln -s /usr/bin/python3 /usr/local/bin/python || true +RUN ln -s /usr/bin/pip3 /usr/local/bin/pip || true + +WORKDIR /usr/src +COPY backends backends +COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py +COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml +RUN cd backends/python/server && \ + make install + +ARG NEURONX_COLLECTIVES_LIB_VERSION=2.28.27.0-bc30ece58 +ARG NEURONX_RUNTIME_LIB_VERSION=2.28.23.0-dd5879008 +ARG NEURONX_TOOLS_VERSION=2.26.14.0 + +ARG NEURONX_CC_VERSION=2.21.33363.0+82129205 +ARG NEURONX_FRAMEWORK_VERSION=2.8.0.2.10.16998+e9bf8a50 +ARG NEURONX_DISTRIBUTED_VERSION=0.15.22404+1f27bddf + +RUN apt-get update \ + && apt-get upgrade -y \ + && apt-get install -y --no-install-recommends \ + apt-transport-https \ + build-essential \ + ca-certificates \ + cmake \ + curl \ + emacs \ + git \ + gnupg2 \ + gpg-agent \ + jq \ + libgl1-mesa-glx \ + libglib2.0-0 \ + libsm6 \ + libxext6 \ + libxrender-dev \ + libcap-dev \ + libhwloc-dev \ + openjdk-11-jdk \ + unzip \ + vim \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* \ + && rm -rf /tmp/tmp* \ + && apt-get clean + +# Ubuntu 22.04 = jammy; use signed-by (apt-key is deprecated) +RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | gpg --dearmor -o /usr/share/keyrings/neuron-archive-keyring.gpg && \ + echo "deb [signed-by=/usr/share/keyrings/neuron-archive-keyring.gpg] https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list + +RUN apt-get update \ + && apt-get install -y \ + aws-neuronx-tools=$NEURONX_TOOLS_VERSION \ + aws-neuronx-collectives=$NEURONX_COLLECTIVES_LIB_VERSION \ + aws-neuronx-runtime-lib=$NEURONX_RUNTIME_LIB_VERSION \ + && rm -rf /var/lib/apt/lists/* \ + && rm -rf /tmp/tmp* \ + && apt-get clean + +ENV PATH="/opt/aws/neuron/bin:${PATH}" + +RUN pip install --index-url https://pip.repos.neuron.amazonaws.com \ + --extra-index-url https://pypi.org/simple \ + --trusted-host pip.repos.neuron.amazonaws.com \ + neuronx-cc==$NEURONX_CC_VERSION \ + torch-neuronx==$NEURONX_FRAMEWORK_VERSION \ + torchvision \ + neuronx_distributed==$NEURONX_DISTRIBUTED_VERSION \ + && rm -rf ~/.cache/pip/* + +# HF ARGS +# Note: optimum-neuron 0.4.4 requires transformers~=4.57.1 +ARG TRANSFORMERS_VERSION=4.57.1 +ARG DIFFUSERS_VERSION=0.35.2 +ARG HUGGINGFACE_HUB_VERSION=0.36.0 +ARG OPTIMUM_NEURON_VERSION=0.4.4 +ARG SENTENCE_TRANSFORMERS=5.1.2 +ARG PEFT_VERSION=0.17.0 +ARG DATASETS_VERSION=4.1.1 + +# Install Hugging Face libraries and dependencies for TEI on Neuron +RUN pip install --no-cache-dir -U \ + networkx==2.8.8 \ + transformers[sentencepiece,audio,vision]==${TRANSFORMERS_VERSION} \ + diffusers==${DIFFUSERS_VERSION} \ + compel \ + controlnet-aux \ + huggingface_hub==${HUGGINGFACE_HUB_VERSION} \ + hf_transfer \ + datasets==${DATASETS_VERSION} \ + optimum-neuron==${OPTIMUM_NEURON_VERSION} \ + sentence_transformers==${SENTENCE_TRANSFORMERS} \ + peft==${PEFT_VERSION} \ + && rm -rf ~/.cache/pip/* + +FROM neuron AS grpc + +COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] + +FROM neuron AS http + +COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] diff --git a/backends/Cargo.toml b/backends/Cargo.toml index bb9d74191..fd0ab74ae 100644 --- a/backends/Cargo.toml +++ b/backends/Cargo.toml @@ -21,6 +21,7 @@ rand = { workspace = true } [features] clap = ["dep:clap", "text-embeddings-backend-core/clap"] python = ["dep:text-embeddings-backend-python"] +python-neuron = ["dep:text-embeddings-backend-python"] ort = ["dep:text-embeddings-backend-ort"] candle = ["dep:text-embeddings-backend-candle"] cuda = ["text-embeddings-backend-candle?/cuda"] diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 1e919f233..8845163eb 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -11,24 +11,28 @@ from text_embeddings_server.models.masked_model import MaskedLanguageModel from text_embeddings_server.models.default_model import DefaultModel from text_embeddings_server.models.classification_model import ClassificationModel -from text_embeddings_server.models.jinaBert_model import FlashJinaBert -from text_embeddings_server.models.flash_mistral import FlashMistral -from text_embeddings_server.models.flash_qwen3 import FlashQwen3 -from text_embeddings_server.utils.device import get_device, use_ipex +from text_embeddings_server.models.habana import wrap_model_if_hpu + +from text_embeddings_server.utils.device import get_device, use_ipex, is_neuron __all__ = ["Model"] TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"] -DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [ - "true", - "1", -] -# Disable gradients -torch.set_grad_enabled(False) +# Flash Attention models - only available when flash_attn is installed FLASH_ATTENTION = True +FlashBert = None +FlashJinaBert = None +FlashMistral = None +FlashQwen3 = None + try: from text_embeddings_server.models.flash_bert import FlashBert + from text_embeddings_server.models.jinaBert_model import FlashJinaBert + from text_embeddings_server.models.flash_mistral import FlashMistral + from text_embeddings_server.models.flash_qwen3 import FlashQwen3 + # Disable gradients + torch.set_grad_enabled(False) except ImportError as e: logger.warning(f"Could not import Flash Attention enabled models: {e}") FLASH_ATTENTION = False @@ -36,16 +40,16 @@ if FLASH_ATTENTION: __all__.append(FlashBert) +# Neuron models - only import when on Neuron device to avoid unnecessary dependencies +create_neuron_model = None -def wrap_model_if_hpu(model_handle, device): - """Wrap the model in HPU graph if the device is HPU.""" - if device.type == "hpu": - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - - model_handle.model = wrap_in_hpu_graph( - model_handle.model, disable_tensor_cache=DISABLE_TENSOR_CACHE +if is_neuron(): + try: + from text_embeddings_server.models.neuron import ( + create_neuron_model, ) - return model_handle + except ImportError as e: + logger.warning(f"Could not import Neuron models: {e}") def create_model(model_class, model_path, device, datatype, pool="cls"): @@ -75,8 +79,26 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str): config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE) + # Neuron cases - use optimum-neuron for all supported model types + if is_neuron(): + logger.info(f"Neuron device detected, using optimum-neuron backend for model type: {config.model_type}") + try: + return create_neuron_model( + model_path=model_path, + device=device, + dtype=datatype, + pool=pool, + trust_remote=TRUST_REMOTE_CODE, + config=config, + ) + except Exception as e: + logger.warning(f"Failed to load model with optimum-neuron: {e}") + logger.warning("Falling back to default model loading path") + # Fall through to default model loading + if ( - hasattr(config, "auto_map") + FlashJinaBert is not None + and hasattr(config, "auto_map") and isinstance(config.auto_map, dict) and "AutoModel" in config.auto_map and config.auto_map["AutoModel"] @@ -116,13 +138,13 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str): else: return create_model(DefaultModel, model_path, device, datatype, pool) - if config.model_type == "mistral" and device.type == "hpu": + if config.model_type == "mistral" and device.type == "hpu" and FlashMistral is not None: try: return create_model(FlashMistral, model_path, device, datatype, pool) except FileNotFoundError: return create_model(DefaultModel, model_path, device, datatype, pool) - if config.model_type == "qwen3" and device.type == "hpu": + if config.model_type == "qwen3" and device.type == "hpu" and FlashQwen3 is not None: try: return create_model(FlashQwen3, model_path, device, datatype, pool) except FileNotFoundError: diff --git a/backends/python/server/text_embeddings_server/models/habana/__init__.py b/backends/python/server/text_embeddings_server/models/habana/__init__.py new file mode 100644 index 000000000..267830de1 --- /dev/null +++ b/backends/python/server/text_embeddings_server/models/habana/__init__.py @@ -0,0 +1,14 @@ +import os + +DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in ["true", "1"] + + +def wrap_model_if_hpu(model_handle, device): + """Wrap the model in HPU graph if the device is HPU.""" + if device.type == "hpu": + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + model_handle.model = wrap_in_hpu_graph( + model_handle.model, disable_tensor_cache=DISABLE_TENSOR_CACHE + ) + return model_handle diff --git a/backends/python/server/text_embeddings_server/models/neuron/__init__.py b/backends/python/server/text_embeddings_server/models/neuron/__init__.py new file mode 100644 index 000000000..80745edc8 --- /dev/null +++ b/backends/python/server/text_embeddings_server/models/neuron/__init__.py @@ -0,0 +1,395 @@ +import inspect +import os +import torch + +from abc import ABC +from pathlib import Path +from typing import Type, List +from opentelemetry import trace +from loguru import logger + +from text_embeddings_server.models.model import Model +from text_embeddings_server.models.types import PaddedBatch, Embedding, Score + +tracer = trace.get_tracer(__name__) + +# Neuron static shapes compilation parameters +NEURON_BATCH_SIZE = int(os.getenv("NEURON_BATCH_SIZE", "1")) +NEURON_SEQUENCE_LENGTH = int(os.getenv("NEURON_SEQUENCE_LENGTH", "512")) + + +class NeuronBaseModel(Model, ABC): + """Base class for all Neuron models.""" + + def __init__( + self, + model, + model_path: Path, + device: torch.device, + dtype: torch.dtype, + ): + self.hidden_size = model.config.hidden_size + + # Calculate max input length based on model type + position_offset = 0 + model_type = model.config.model_type + if model_type in ["xlm-roberta", "camembert", "roberta"]: + position_offset = getattr(model.config, "pad_token_id", 1) + 1 + + if hasattr(model.config, "max_seq_length"): + self.max_input_length = model.config.max_seq_length + elif hasattr(model.config, "n_positions"): + self.max_input_length = model.config.n_positions + else: + self.max_input_length = ( + model.config.max_position_embeddings - position_offset + ) + + # Check which inputs the model supports + self.has_position_ids = self._check_param_exists(model, "position_ids") + self.has_token_type_ids = self._check_param_exists(model, "token_type_ids") + + super().__init__(model=model, dtype=dtype, device=device) + + @staticmethod + def _check_param_exists(model, param_name: str) -> bool: + """Check if a parameter exists in the model's forward signature.""" + try: + forward_fn = model.forward if hasattr(model, 'forward') else model.__call__ + return ( + inspect.signature(forward_fn).parameters.get(param_name, None) + is not None + ) + except (ValueError, TypeError): + return False + + @property + def batch_type(self) -> Type[PaddedBatch]: + return PaddedBatch + + def _prepare_inputs(self, batch: PaddedBatch) -> dict: + """Prepare input kwargs for model forward pass. + + Note: Neuron models require int64 (long) tensors for inputs. + """ + kwargs = { + "input_ids": batch.input_ids.to(torch.long), + "attention_mask": batch.attention_mask.to(torch.long), + } + if self.has_token_type_ids: + kwargs["token_type_ids"] = batch.token_type_ids.to(torch.long) + if self.has_position_ids: + kwargs["position_ids"] = batch.position_ids.to(torch.long) + return kwargs + + +class NeuronSentenceTransformersModel(Model): + """ + Neuron model for sentence-transformers. + + Uses optimum.neuron.NeuronSentenceTransformers which is designed + for sentence embedding models. + """ + + def __init__( + self, + model_path: Path, + device: torch.device, + dtype: torch.dtype, + pool: str = "cls", + trust_remote: bool = False, + ): + from optimum.neuron import NeuronSentenceTransformers + from transformers import AutoConfig + + # Load config separately for reliable access + config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote) + self.hidden_size = config.hidden_size + + position_offset = 0 + model_type = config.model_type + if model_type in ["xlm-roberta", "camembert", "roberta"]: + position_offset = getattr(config, "pad_token_id", 1) + 1 + + if hasattr(config, "max_seq_length"): + self.max_input_length = config.max_seq_length + elif hasattr(config, "n_positions"): + self.max_input_length = config.n_positions + else: + self.max_input_length = ( + config.max_position_embeddings - position_offset + ) + + is_compiled = self._is_neuron_compiled(model_path) + if not is_compiled: + logger.info(f"Compiling model for Neuron with batch_size={NEURON_BATCH_SIZE}, sequence_length={NEURON_SEQUENCE_LENGTH}") + model = NeuronSentenceTransformers.from_pretrained( + model_path, + export=True, + batch_size=NEURON_BATCH_SIZE, + sequence_length=NEURON_SEQUENCE_LENGTH, + ) + else: + model = NeuronSentenceTransformers.from_pretrained(model_path) + + self.pool = pool + super().__init__(model=model, dtype=dtype, device=device) + logger.info(f"Loaded NeuronSentenceTransformersModel with pool={pool}, hidden_size={self.hidden_size}") + + @staticmethod + def _is_neuron_compiled(model_path: Path) -> bool: + """Check if the model is already compiled for Neuron.""" + neuron_files = list(model_path.glob("*.neuron")) if model_path.is_dir() else [] + return len(neuron_files) > 0 + + @property + def batch_type(self) -> Type[PaddedBatch]: + return PaddedBatch + + @tracer.start_as_current_span("embed") + def embed(self, batch: PaddedBatch) -> List[Embedding]: + # Prepare inputs + input_ids = batch.input_ids.to(torch.long) + attention_mask = batch.attention_mask.to(torch.long) + + output = self.model(input_ids, attention_mask) + + # Get sentence embeddings from output + sentence_embedding = None + if isinstance(output, dict): + # Check if sentence_embedding exists and has non-zero values + # NeuronSentenceTransformers may return zeros for sentence_embedding when pooling fails + has_valid_sentence_embedding = ( + "sentence_embedding" in output + and output["sentence_embedding"] is not None + and output["sentence_embedding"].abs().sum() > 0 + ) + if has_valid_sentence_embedding: + sentence_embedding = output["sentence_embedding"] + elif "token_embeddings" in output and output["token_embeddings"] is not None: + # Apply manual pooling when sentence_embedding is not valid + logger.debug(f"Using token_embeddings with manual {self.pool} pooling") + token_embeddings = output["token_embeddings"] + + if self.pool == "cls": + sentence_embedding = token_embeddings[:, 0, :] + elif self.pool == "mean": + mask = attention_mask.unsqueeze(-1).float() + sentence_embedding = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1) + elif self.pool == "last_token": + seq_lengths = attention_mask.sum(dim=1) - 1 + sentence_embedding = token_embeddings[torch.arange(token_embeddings.size(0)), seq_lengths] + else: + raise ValueError(f"Invalid pooling mode: {self.pool}") + else: + raise ValueError(f"Cannot extract embeddings from model output dict: {output.keys()}") + elif hasattr(output, "sentence_embedding") and output.sentence_embedding is not None: + sentence_embedding = output.sentence_embedding + elif hasattr(output, "token_embeddings") and output.token_embeddings is not None: + token_embeddings = output.token_embeddings + if self.pool == "cls": + sentence_embedding = token_embeddings[:, 0, :] + elif self.pool == "mean": + mask = attention_mask.unsqueeze(-1).float() + sentence_embedding = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1) + elif self.pool == "last_token": + seq_lengths = attention_mask.sum(dim=1) - 1 + sentence_embedding = token_embeddings[torch.arange(token_embeddings.size(0)), seq_lengths] + else: + raise ValueError(f"Invalid pooling mode: {self.pool}") + elif torch.is_tensor(output): + # Assume output is the sentence embedding tensor directly + sentence_embedding = output + else: + raise ValueError(f"Cannot extract embeddings from model output: type={type(output)}") + + # Convert to list format expected by the gRPC interface + cpu_results = sentence_embedding.view(-1).tolist() + + return [ + Embedding( + values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] + ) + for i in range(len(batch)) + ] + + @tracer.start_as_current_span("predict") + def predict(self, batch: PaddedBatch) -> List[Score]: + raise NotImplementedError("Prediction not supported for sentence transformer models") + + +class NeuronClassificationModel(NeuronBaseModel): + """ + Neuron-optimized model for sequence classification. + + Uses optimum.neuron.NeuronModelForSequenceClassification for classification tasks. + """ + + def __init__( + self, + model_path: Path, + device: torch.device, + dtype: torch.dtype, + pool: str = "cls", + trust_remote: bool = False, + ): + from optimum.neuron import NeuronModelForSequenceClassification + + is_compiled = self._is_neuron_compiled(model_path) + export_kwargs = {} + if not is_compiled: + export_kwargs = { + "export": True, + "batch_size": NEURON_BATCH_SIZE, + "sequence_length": NEURON_SEQUENCE_LENGTH, + } + logger.info(f"Compiling model for Neuron with batch_size={NEURON_BATCH_SIZE}, sequence_length={NEURON_SEQUENCE_LENGTH}") + model = NeuronModelForSequenceClassification.from_pretrained( + model_path, + **export_kwargs, + ) + + super().__init__(model, model_path, device, dtype) + logger.info("Loaded NeuronClassificationModel") + + @staticmethod + def _is_neuron_compiled(model_path: Path) -> bool: + """Check if the model is already compiled for Neuron.""" + neuron_files = list(model_path.glob("*.neuron")) if model_path.is_dir() else [] + return len(neuron_files) > 0 + + @tracer.start_as_current_span("embed") + def embed(self, batch: PaddedBatch) -> List[Embedding]: + raise NotImplementedError("Embedding not supported for classification models") + + @tracer.start_as_current_span("predict") + def predict(self, batch: PaddedBatch) -> List[Score]: + kwargs = self._prepare_inputs(batch) + output = self.model(**kwargs) + + # Get logits from output + if hasattr(output, "logits"): + logits = output.logits + else: + logits = output[0] + + all_scores = logits.tolist() + return [Score(values=scores) for scores in all_scores] + + +class NeuronMaskedLMModel(NeuronBaseModel): + """ + Neuron-optimized model for Masked Language Modeling (SPLADE). + + Uses optimum.neuron.NeuronModelForMaskedLM for SPLADE-style sparse embeddings. + """ + + def __init__( + self, + model_path: Path, + device: torch.device, + dtype: torch.dtype, + pool: str = "splade", + trust_remote: bool = False, + ): + from optimum.neuron import NeuronModelForMaskedLM + + is_compiled = self._is_neuron_compiled(model_path) + export_kwargs = {} + if not is_compiled: + export_kwargs = { + "export": True, + "batch_size": NEURON_BATCH_SIZE, + "sequence_length": NEURON_SEQUENCE_LENGTH, + } + logger.info(f"Compiling model for Neuron with batch_size={NEURON_BATCH_SIZE}, sequence_length={NEURON_SEQUENCE_LENGTH}") + model = NeuronModelForMaskedLM.from_pretrained( + model_path, + **export_kwargs, + ) + + super().__init__(model, model_path, device, dtype) + + # Get vocab size for SPLADE output + self.vocab_size = model.config.vocab_size + logger.info(f"Loaded NeuronMaskedLMModel with vocab_size={self.vocab_size}") + + @staticmethod + def _is_neuron_compiled(model_path: Path) -> bool: + """Check if the model is already compiled for Neuron.""" + neuron_files = list(model_path.glob("*.neuron")) if model_path.is_dir() else [] + return len(neuron_files) > 0 + + @tracer.start_as_current_span("embed") + def embed(self, batch: PaddedBatch) -> List[Embedding]: + kwargs = self._prepare_inputs(batch) + output = self.model(**kwargs) + + # Get logits for SPLADE pooling + if hasattr(output, "logits"): + hidden_states = output.logits + else: + hidden_states = output[0] + + # SPLADE pooling: ReLU -> log(1+x) -> max pooling + hidden_states = torch.relu(hidden_states) + hidden_states = (1 + hidden_states).log() + hidden_states = torch.mul(hidden_states, batch.attention_mask.unsqueeze(-1)) + sparse_embedding = hidden_states.max(dim=1).values + + cpu_results = sparse_embedding.view(-1).tolist() + + return [ + Embedding( + values=cpu_results[i * self.vocab_size : (i + 1) * self.vocab_size] + ) + for i in range(len(batch)) + ] + + @tracer.start_as_current_span("predict") + def predict(self, batch: PaddedBatch) -> List[Score]: + raise NotImplementedError("Prediction not supported for masked LM models") + + +def create_neuron_model( + model_path: Path, + device: torch.device, + dtype: torch.dtype, + pool: str = "cls", + trust_remote: bool = False, + config=None, +) -> Model: + """ + Factory function to create the appropriate Neuron model based on the model config. + + Args: + model_path: Path to the model + device: Target device (should be xla for Neuron) + dtype: Data type for the model + pool: Pooling strategy (cls, mean, lasttoken, splade) + trust_remote: Whether to trust remote code + config: Pre-loaded model config (optional) + + Returns: + Appropriate Neuron model instance + """ + from transformers import AutoConfig + + if config is None: + config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote) + + architectures = getattr(config, "architectures", []) or [] + architecture = architectures[0] if architectures else "" + + logger.info(f"Creating Neuron model for architecture: {architecture}, pool: {pool}") + + # Check for classification models + if architecture.endswith("ForSequenceClassification") or architecture.endswith("Classification"): + return NeuronClassificationModel(model_path, device, dtype, pool, trust_remote) + + # Check for SPLADE (masked LM) models + if pool == "splade" or architecture.endswith("ForMaskedLM"): + return NeuronMaskedLMModel(model_path, device, dtype, pool, trust_remote) + + # Default to NeuronSentenceTransformers for all embedding models + return NeuronSentenceTransformersModel(model_path, device, dtype, pool, trust_remote) diff --git a/backends/python/server/text_embeddings_server/utils/device.py b/backends/python/server/text_embeddings_server/utils/device.py index 3f3b04dd7..4963b012c 100644 --- a/backends/python/server/text_embeddings_server/utils/device.py +++ b/backends/python/server/text_embeddings_server/utils/device.py @@ -1,4 +1,6 @@ import os +import re +import functools from loguru import logger import importlib.metadata import importlib.util @@ -49,6 +51,21 @@ def is_hpu() -> bool: is_hpu_available = False return is_hpu_available +@functools.cache +def get_neuron_major() -> int: + MAJORS_FILE = "/proc/devices" + NEURON_MAJOR_LINE = re.compile(r"^\s*(\d+)\s+neuron\s*$") + if not os.path.exists(MAJORS_FILE): + return -1 + with open(MAJORS_FILE, "r") as f: + for l in f.readlines(): + m = NEURON_MAJOR_LINE.match(l) + if m: + return int(m.group(1)) + return -1 + +def is_neuron() -> bool: + return get_neuron_major() > -1 def use_ipex() -> bool: value = os.environ.get("USE_IPEX", "True").lower() @@ -72,5 +89,7 @@ def get_device(): if hasattr(torch, "xpu") and torch.xpu.is_available(): device = torch.device("xpu") + elif is_neuron(): + device = torch.device("xla") return device diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs index 80292be79..ef16ca556 100644 --- a/backends/src/dtype.rs +++ b/backends/src/dtype.rs @@ -9,12 +9,18 @@ pub enum DType { // Float16 is not available on accelerate #[cfg(any( feature = "python", + feature = "python-neuron", all(feature = "candle", not(feature = "accelerate")) ))] Float16, - #[cfg(any(feature = "python", feature = "candle", feature = "ort"))] + #[cfg(any( + feature = "python", + feature = "python-neuron", + feature = "candle", + feature = "ort" + ))] Float32, - #[cfg(feature = "python")] + #[cfg(any(feature = "python", feature = "python-neuron"))] Bfloat16, } @@ -24,12 +30,18 @@ impl fmt::Display for DType { // Float16 is not available on accelerate #[cfg(any( feature = "python", + feature = "python-neuron", all(feature = "candle", not(feature = "accelerate")) ))] DType::Float16 => write!(f, "float16"), - #[cfg(any(feature = "python", feature = "candle", feature = "ort"))] + #[cfg(any( + feature = "python", + feature = "python-neuron", + feature = "candle", + feature = "ort" + ))] DType::Float32 => write!(f, "float32"), - #[cfg(feature = "python")] + #[cfg(any(feature = "python", feature = "python-neuron"))] DType::Bfloat16 => write!(f, "bfloat16"), } } @@ -46,12 +58,13 @@ impl Default for DType { feature = "accelerate", feature = "mkl", feature = "ort", - feature = "python" + feature = "python", + feature = "python-neuron" )))] { DType::Float16 } - #[cfg(feature = "python")] + #[cfg(any(feature = "python", feature = "python-neuron"))] { DType::Bfloat16 } diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 479b310e6..8f9ee2838 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -28,7 +28,7 @@ use text_embeddings_backend_candle::CandleBackend; #[cfg(feature = "ort")] use text_embeddings_backend_ort::OrtBackend; -#[cfg(feature = "python")] +#[cfg(any(feature = "python", feature = "python-neuron"))] use text_embeddings_backend_python::PythonBackend; fn powers_of_two(max_value: usize) -> Vec { @@ -416,10 +416,43 @@ async fn init_backend( } if let Some(api_repo) = api_repo.as_ref() { - if cfg!(feature = "python") || cfg!(feature = "candle") { - let start = std::time::Instant::now(); + let start = std::time::Instant::now(); + if cfg!(feature = "python-neuron") { + #[cfg(feature = "python-neuron")] + { + tracing::info!("Downloading `model.neuron`"); + let model_files = download_neuron(api_repo) + .await + .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?; + + if model_files.is_empty() { + tracing::warn!( + "Neuron model files not found in the repository. \ + The Python backend will attempt to compile the model on-the-fly using optimum-neuron. \ + This may take several minutes. For faster startup, consider pre-compiling your model: \ + https://huggingface.co/docs/optimum-neuron/en/model_doc/sentence_transformers/overview" + ); + // Fall back to downloading regular model files for on-the-fly compilation + if download_safetensors(api_repo.clone()).await.is_err() { + tracing::warn!( + "safetensors weights not found. Using `pytorch_model.bin` instead." + ); + tracing::info!("Downloading `pytorch_model.bin`"); + api_repo + .get("pytorch_model.bin") + .await + .map_err(|err| BackendError::WeightsNotFound(err.to_string()))?; + } + } + + tracing::info!("Neuron model downloaded in {:?}", start.elapsed()); + } + } else if cfg!(feature = "python") || cfg!(feature = "candle") { if download_safetensors(api_repo.clone()).await.is_err() { - tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower."); + tracing::warn!( + "safetensors weights not found. Using `pytorch_model.bin` instead. \ + Model loading will be significantly slower." + ); tracing::info!("Downloading `pytorch_model.bin`"); api_repo .get("pytorch_model.bin") @@ -494,8 +527,8 @@ async fn init_backend( } } - if cfg!(feature = "python") { - #[cfg(feature = "python")] + if cfg!(feature = "python") || cfg!(feature = "python-neuron") { + #[cfg(any(feature = "python", feature = "python-neuron"))] { let backend = std::thread::spawn(move || { PythonBackend::new( @@ -736,6 +769,21 @@ async fn download_onnx(api: Arc) -> Result, ApiError> { } } +#[cfg(feature = "python-neuron")] +async fn download_neuron(api: &ApiRepo) -> Result, ApiError> { + let mut model_files: Vec = Vec::new(); + + tracing::info!("Downloading `model.neuron`"); + match api.get("model.neuron").await { + Ok(p) => model_files.push(p), + Err(err) => { + tracing::warn!("Could not download `model.neuron`: {err}"); + } + }; + + Ok(model_files) +} + #[cfg(feature = "candle")] #[derive(Debug, Clone, Deserialize, PartialEq)] enum ModuleType { diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index fa6f21e63..69ace4e17 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -19,6 +19,8 @@ title: Build custom container for TEI - local: intel_container title: Using TEI container with Intel Hardware + - local: aws_neuron + title: Using TEI with AWS Trainium and Inferentia - local: examples title: Example uses title: Tutorials diff --git a/docs/source/en/aws_neuron.md b/docs/source/en/aws_neuron.md new file mode 100644 index 000000000..2d02999a6 --- /dev/null +++ b/docs/source/en/aws_neuron.md @@ -0,0 +1,105 @@ + +# Using TEI with AWS Trainium and Inferentia + +Text Embeddings Inference (TEI) supports AWS Trainium and Inferentia accelerators through the [optimum-neuron](https://huggingface.co/docs/optimum-neuron) library. + +## Supported Model Types + +- **Embedding models**: Uses `NeuronSentenceTransformers` for sentence embeddings (e.g., BGE, sentence-transformers models) +- **Classification models**: Uses `NeuronModelForSequenceClassification` for sequence classification tasks +- **SPLADE models**: Uses `NeuronModelForMaskedLM` for sparse embeddings + +## Build Docker Image + +To build a container optimized for AWS Neuron devices: + +```shell +docker build . -f Dockerfile-neuron -t tei-neuron:main +``` + +## Deploy with Pre-compiled Models + +Pre-compiled models are recommended for production use as they skip the compilation step and start faster. + +```shell +model='optimum/bge-base-en-v1.5-neuronx' +volume=$PWD/data + +docker run --privileged \ + -p 8080:80 \ + -v $volume:/data \ + tei-neuron:main \ + --model-id $model \ + --dtype float32 +``` + +> **Note**: The `--privileged` flag is required for the Neuron OCI hook to work properly. + +## Deploy with On-the-fly Compilation + +You can also use non-pre-compiled models. TEI will compile the model for Neuron automatically on first load. This takes additional time but allows you to use any compatible model. + +```shell +model='BAAI/bge-base-en-v1.5' +volume=$PWD/data + +docker run --privileged \ + -p 8080:80 \ + -v $volume:/data \ + -e NEURON_BATCH_SIZE=1 \ + -e NEURON_SEQUENCE_LENGTH=512 \ + tei-neuron:main \ + --model-id $model \ + --dtype float32 +``` + +### Compilation Environment Variables + +When using on-the-fly compilation, you can configure the following environment variables: + +| Variable | Default | Description | +|----------|---------|-------------| +| `NEURON_BATCH_SIZE` | 1 | Batch size for Neuron compilation (static shape) | +| `NEURON_SEQUENCE_LENGTH` | 512 | Maximum sequence length for Neuron compilation (static shape) | + +> **Note**: Neuron requires static shapes for compilation. The batch size and sequence length are fixed at compilation time. + +## Runtime Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `NEURON_RT_NUM_CORES` | 1 | Number of Neuron cores to use | +| `NEURON_RT_VISIBLE_CORES` | 0 | Which Neuron cores are visible to the runtime | + +## Pre-compiled Models + +For faster startup, use pre-compiled Neuron models from the Hugging Face Hub like: + +- [optimum/bge-base-en-v1.5-neuronx](https://huggingface.co/optimum/bge-base-en-v1.5-neuronx) + +You can also compile your own models using the [Optimum Neuron guide](https://huggingface.co/docs/optimum-neuron/en/model_doc/sentence_transformers/overview). + +## Testing Your Deployment + +Once the container is running, you can test the embedding endpoint: + +```shell +curl 127.0.0.1:8080/embed \ + -X POST \ + -H 'Content-Type: application/json' \ + -d '{"inputs": "What is Deep Learning?"}' +``` diff --git a/integration_tests/README.md b/integration_tests/README.md index 641d8fce3..ca20fbb9c 100644 --- a/integration_tests/README.md +++ b/integration_tests/README.md @@ -1,8 +1,18 @@ # Integration Tests -This directory contains integration tests for the project. This starts the TEI server and run an /embed request to it while checking the output is as expected. +This directory contains integration tests for the project. This starts the TEI server and runs an /embed request to it while checking the output is as expected. -## Running the tests for HPU +## How Tests Work + +The tests use pytest fixtures to: +1. Start a Docker container with the TEI server +2. Wait for the server to become healthy +3. Send embedding requests and validate responses +4. Stop and remove the container after tests complete + +The Docker image must be built before running tests. The `uv run pytest` command will start containers automatically using the pre-built image. + +## Running the tests for HPU (Habana Gaudi) First you have to build the docker image. ```bash @@ -13,5 +23,27 @@ docker build . -f Dockerfile-intel --build-arg PLATFORM=$platform -t tei_hpu Then you can run the tests. ```bash +cd integration_tests/gaudi +uv run pytest --durations=0 -sv . +``` + +## Running the tests for Neuron (AWS Inferentia/Trainium) + +### Prerequisites + +1. **AWS Neuron instance**: Tests must run on an EC2 instance with Neuron devices (inf2, trn1 or trn2) +2. **Neuron drivers**: Ensure Neuron drivers are installed and `/dev/neuron*` devices are available +3. **Pre-compiled models**: Neuron requires models to be pre-compiled to `.neuron` format + +### Building the Docker Image + +```bash +docker build . -f Dockerfile-neuron -t tei-neuron +``` + +### Running the Tests + +```bash +cd integration_tests/neuron uv run pytest --durations=0 -sv . ``` diff --git a/integration_tests/neuron/conftest.py b/integration_tests/neuron/conftest.py new file mode 100644 index 000000000..40d16b05a --- /dev/null +++ b/integration_tests/neuron/conftest.py @@ -0,0 +1,299 @@ +import asyncio +import contextlib +import os +import shlex +import subprocess +import sys +import threading +import time +from tempfile import TemporaryDirectory + +import docker +import pytest +from docker.errors import NotFound +import logging +from test_embed import TEST_CONFIGS +import aiohttp + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d - %(message)s", + stream=sys.stdout, +) +logger = logging.getLogger(__file__) + +# Use the latest image from the local docker build +DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tei-neuron") +DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None) + +if DOCKER_VOLUME is None: + logger.warning( + "DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing" + ) + +LOG_LEVEL = os.getenv("LOG_LEVEL", "info") + +BASE_ENV = { + "HF_HUB_ENABLE_HF_TRANSFER": "1", + "LOG_LEVEL": LOG_LEVEL, + # Neuron-specific environment variables + "NEURON_RT_NUM_CORES": os.getenv("NEURON_RT_NUM_CORES", "1"), + "NEURON_RT_VISIBLE_CORES": os.getenv("NEURON_RT_VISIBLE_CORES", "0"), +} + +# Neuron requires privileged mode for OCI hook to work +NEURON_RUN_ARGS = { + "privileged": True, +} + + +def stream_container_logs(container, test_name): + """Stream container logs in a separate thread.""" + try: + for log in container.logs(stream=True, follow=True): + print( + f"[TEI Server Logs - {test_name}] {log.decode('utf-8')}", + end="", + file=sys.stderr, + flush=True, + ) + except Exception as e: + logger.error(f"Error streaming container logs: {str(e)}") + + +class LauncherHandle: + def __init__(self, port: int): + self.port = port + self.base_url = f"http://127.0.0.1:{port}" + + async def generate(self, prompt: str): + """Send embed request to the TEI server (alias for embed).""" + return await self.embed(prompt) + + async def embed(self, text: str): + """Send embed request to the TEI server.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/embed", + json={"inputs": text}, + headers={"Content-Type": "application/json"} + ) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"Request failed with status {response.status}: {error_text}") + return await response.json() + + async def embed_batch(self, texts: list): + """Send batch embed request to the TEI server.""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/embed", + json={"inputs": texts}, + headers={"Content-Type": "application/json"} + ) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"Request failed with status {response.status}: {error_text}") + return await response.json() + + async def predict(self, text: str): + """Send predict request to the TEI server (for classification models).""" + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.base_url}/predict", + json={"inputs": text}, + headers={"Content-Type": "application/json"} + ) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError(f"Request failed with status {response.status}: {error_text}") + return await response.json() + + def _inner_health(self): + raise NotImplementedError + + async def health(self, timeout: int = 300): + """Wait for the server to be healthy. + + Neuron models may take longer to compile/load, so default timeout is higher. + """ + assert timeout > 0 + start_time = time.time() + logger.info(f"Starting health check with timeout of {timeout}s") + + for attempt in range(timeout): + if not self._inner_health(): + logger.error("Launcher crashed during health check") + raise RuntimeError("Launcher crashed") + + try: + # Try to make a request using generate (like Habana tests) + await self.generate("test") + elapsed = time.time() - start_time + logger.info(f"Health check passed after {elapsed:.1f}s") + return + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + if attempt == timeout - 1: + logger.error(f"Health check failed after {timeout}s: {str(e)}") + raise RuntimeError(f"Health check failed: {str(e)}") + if attempt % 10 == 0 and attempt != 0: # Only log every 10th attempt + logger.debug(f"Connection attempt {attempt}/{timeout} failed: {str(e)}") + await asyncio.sleep(1) + except Exception as e: + logger.error(f"Unexpected error during health check: {str(e)}") + import traceback + logger.error(f"Full traceback:\n{traceback.format_exc()}") + raise + + +class ContainerLauncherHandle(LauncherHandle): + def __init__(self, docker_client, container_name, port: int): + super().__init__(port) + self.docker_client = docker_client + self.container_name = container_name + + def _inner_health(self) -> bool: + try: + container = self.docker_client.containers.get(self.container_name) + status = container.status + if status not in ["running", "created"]: + logger.warning(f"Container status is {status}") + # Get container logs for debugging + logs = container.logs().decode("utf-8") + logger.debug(f"Container logs:\n{logs}") + return False + return True + except Exception as e: + logger.error(f"Error checking container health: {str(e)}") + return False + + +class ProcessLauncherHandle(LauncherHandle): + def __init__(self, process, port: int): + super(ProcessLauncherHandle, self).__init__(port) + self.process = process + + def _inner_health(self) -> bool: + return self.process.poll() is None + + +@pytest.fixture(scope="module") +def data_volume(): + tmpdir = TemporaryDirectory() + yield tmpdir.name + try: + # Cleanup the temporary directory using sudo as it contains root files created by the container + subprocess.run(shlex.split(f"sudo rm -rf {tmpdir.name}"), check=True) + except subprocess.CalledProcessError as e: + logger.error(f"Error cleaning up temporary directory: {str(e)}") + + +@pytest.fixture(scope="function") +def neuron_launcher(): + @contextlib.contextmanager + def docker_launcher( + model_id: str, + test_name: str, + ): + logger.info( + f"Starting docker launcher for model {model_id} and test {test_name}" + ) + + port = 8080 + + client = docker.from_env() + + container_name = f"tei-neuron-test-{test_name.replace('/', '-').replace('_', '-')}" + + try: + container = client.containers.get(container_name) + logger.info( + f"Stopping existing container {container_name} for test {test_name}" + ) + container.stop() + container.wait() + except NotFound: + pass + except Exception as e: + logger.error(f"Error handling existing container: {str(e)}") + + tei_args = TEST_CONFIGS[test_name]["args"].copy() + + # add model_id to tei args + tei_args.append("--model-id") + tei_args.append(model_id) + + env = BASE_ENV.copy() + env["HF_TOKEN"] = os.getenv("HF_TOKEN") + + # Add env config that is defined in the fixture parameter + if "env_config" in TEST_CONFIGS[test_name]: + env.update(TEST_CONFIGS[test_name]["env_config"].copy()) + + volumes = [f"{DOCKER_VOLUME}:/data"] if DOCKER_VOLUME else [] + logger.debug(f"Using volume {volumes}") + + try: + logger.info(f"Creating container with name {container_name}") + + # Build run arguments - use privileged mode for Neuron OCI hook + run_args = NEURON_RUN_ARGS.copy() + + container = client.containers.run( + DOCKER_IMAGE, + command=tei_args, + name=container_name, + environment=env, + detach=True, + volumes=volumes if volumes else None, + ports={"80/tcp": port}, + **run_args, + ) + + logger.info(f"Container {container_name} started successfully") + + # Start log streaming in a background thread + log_thread = threading.Thread( + target=stream_container_logs, + args=(container, test_name), + daemon=True, # This ensures the thread will be killed when the main program exits + ) + log_thread.start() + + # Add a small delay to allow container to initialize + time.sleep(2) + + # Check container status after creation + status = container.status + logger.debug(f"Initial container status: {status}") + if status not in ["running", "created"]: + logs = container.logs().decode("utf-8") + logger.error(f"Container failed to start properly. Logs:\n{logs}") + + yield ContainerLauncherHandle(client, container.name, port) + + except Exception as e: + logger.error(f"Error starting container: {str(e)}") + # Get full traceback for debugging + import traceback + + logger.error(f"Full traceback:\n{traceback.format_exc()}") + raise + finally: + try: + container = client.containers.get(container_name) + logger.info(f"Stopping container {container_name}") + container.stop() + container.wait() + + container_output = container.logs().decode("utf-8") + print(container_output, file=sys.stderr) + + container.remove() + logger.info(f"Container {container_name} removed successfully") + except NotFound: + pass + except Exception as e: + logger.warning(f"Error cleaning up container: {str(e)}") + + return docker_launcher diff --git a/integration_tests/neuron/test_embed.py b/integration_tests/neuron/test_embed.py new file mode 100644 index 000000000..4ca4aadb9 --- /dev/null +++ b/integration_tests/neuron/test_embed.py @@ -0,0 +1,185 @@ +from typing import Any, Dict, Generator +from _pytest.fixtures import SubRequest + +import pytest +import pytest_asyncio +import numpy as np + + +# Test configurations for Neuron backend +TEST_CONFIGS = { + # On-the-fly Neuron compilation + "sentence-transformers/all-MiniLM-L6-v2": { + "model_id": "sentence-transformers/all-MiniLM-L6-v2", + "input": "What is Deep Learning?", + "batch_inputs": [ + "What is Deep Learning?", + "How does machine learning work?", + "Tell me about neural networks.", + ], + "expected_output_prefix": None, + "args": [ + "--dtype", "float32", + "--max-batch-requests", "1", + ], + "env_config": { + "MAX_WARMUP_SEQUENCE_LENGTH": "512", + }, + }, + "BAAI/bge-base-en-v1.5": { + "model_id": "BAAI/bge-base-en-v1.5", + "input": "What is Deep Learning?", + "batch_inputs": [ + "What is Deep Learning?", + "How does machine learning work?", + "Tell me about neural networks.", + ], + "expected_output_prefix": None, + "args": [ + "--dtype", "float32", + "--max-batch-requests", "1", + ], + "env_config": { + "MAX_WARMUP_SEQUENCE_LENGTH": "512", + }, + }, + # Pre-compiled Neuron model + "optimum/bge-base-en-v1.5-neuronx": { + "model_id": "optimum/bge-base-en-v1.5-neuronx", + "input": "What is Deep Learning?", + "batch_inputs": [ + "What is Deep Learning?", + "How does machine learning work?", + "Tell me about neural networks.", + ], + "expected_output_prefix": None, + "args": [ + "--dtype", "float32", + "--max-batch-requests", "1", + ], + "env_config": { + "MAX_WARMUP_SEQUENCE_LENGTH": "512", + }, + }, +} + + +@pytest.fixture(scope="module", params=TEST_CONFIGS.keys()) +def test_config(request: SubRequest) -> Dict[str, Any]: + """Fixture that provides model configurations for testing.""" + model_name = request.param + test_config = TEST_CONFIGS[model_name].copy() + test_config["test_name"] = model_name + return test_config + + +@pytest.fixture(scope="module") +def model_id(test_config: Dict[str, Any]) -> Generator[str, None, None]: + yield test_config["model_id"] + + +@pytest.fixture(scope="module") +def test_name(test_config: Dict[str, Any]) -> Generator[str, None, None]: + yield test_config["test_name"] + + +@pytest.fixture(scope="module") +def input_text(test_config: Dict[str, Any]) -> str: + return test_config["input"] + + +@pytest.fixture(scope="module") +def batch_inputs(test_config: Dict[str, Any]) -> list: + return test_config.get("batch_inputs", [test_config["input"]]) + + +@pytest.fixture(scope="module") +def expected_outputs(test_config: Dict[str, Any]) -> Dict[str, Any]: + return { + "expected_output_prefix": test_config.get("expected_output_prefix"), + } + + +@pytest.fixture(scope="function") +def tei_service(neuron_launcher, model_id: str, test_name: str): + with neuron_launcher(model_id, test_name) as tei_service: + yield tei_service + + +@pytest_asyncio.fixture(scope="function") +async def tei_client(tei_service): + # Neuron models may take longer to load due to compilation + await tei_service.health(600) # 10 minute timeout for Neuron compilation + return tei_service + + +@pytest.mark.asyncio +async def test_model_single_request( + tei_client, expected_outputs: Dict[str, Any], input_text: str +): + """Test single embedding request.""" + response = await tei_client.embed(input_text) + + # Verify response structure + assert isinstance(response, list), f"Expected list, got {type(response)}" + assert len(response) > 0, "Embedding should not be empty" + + response_array = np.array(response) + + # Check that values are numeric + assert response_array.dtype in [np.float32, np.float64, np.float16], \ + f"Expected float array, got {response_array.dtype}" + + # If expected output is provided, validate against it + expected_prefix = expected_outputs.get("expected_output_prefix") + if expected_prefix is not None: + expected_array = np.array(eval(expected_prefix) if isinstance(expected_prefix, str) else expected_prefix) + prefix_len = len(expected_array.flatten()) + response_flat = response_array.flatten()[:prefix_len] + + if not np.allclose(response_flat, expected_array.flatten(), rtol=1e-4, atol=1e-4): + print("\nExpected output (prefix):") + print(f"{expected_array.tolist()}") + print("\nReceived output (prefix):") + print(f"{response_flat.tolist()}") + raise AssertionError("Response array does not match expected array within tolerance") + + # Check embedding dimensions are reasonable (typically 384, 768, 1024, etc.) + embedding_dim = response_array.shape[-1] if response_array.ndim > 1 else len(response_array) + assert embedding_dim > 0, "Embedding dimension should be positive" + + print(f"Single request embedding shape: {response_array.shape}") + print(f"Embedding dimension: {embedding_dim}") + + +@pytest.mark.asyncio +async def test_model_batch_request(tei_client, batch_inputs: list): + """Test batch embedding request.""" + response = await tei_client.embed_batch(batch_inputs) + + # Verify response is a list of embeddings + assert isinstance(response, list), f"Expected list, got {type(response)}" + assert len(response) == len(batch_inputs), \ + f"Expected {len(batch_inputs)} embeddings, got {len(response)}" + + response_array = np.array(response) + print(f"Batch request response shape: {response_array.shape}") + + # Check each embedding + for i, embedding in enumerate(response): + assert isinstance(embedding, list), f"Embedding {i} should be a list" + assert len(embedding) > 0, f"Embedding {i} should not be empty" + + +@pytest.mark.asyncio +async def test_model_embedding_consistency(tei_client, input_text: str): + """Test that the same input produces consistent embeddings.""" + response1 = await tei_client.embed(input_text) + response2 = await tei_client.embed(input_text) + + array1 = np.array(response1) + array2 = np.array(response2) + + # Embeddings for the same input should be identical (or very close) + assert np.allclose(array1, array2, rtol=1e-4, atol=1e-4), \ + "Same input should produce consistent embeddings" diff --git a/router/Cargo.toml b/router/Cargo.toml index 381d611c0..605fa4dc3 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -86,6 +86,7 @@ metal = ["text-embeddings-backend/metal"] mkl = ["text-embeddings-backend/mkl", "dep:intel-mkl-src"] accelerate = ["text-embeddings-backend/accelerate"] python = ["text-embeddings-backend/python"] +python-neuron = ["text-embeddings-backend/python-neuron"] ort = ["text-embeddings-backend/ort"] candle = ["text-embeddings-backend/candle"] candle-cuda = ["candle", "text-embeddings-backend/flash-attn", "dep:cudarc"]