-
Notifications
You must be signed in to change notification settings - Fork 370
Expand file tree
/
Copy path__init__.py
More file actions
179 lines (152 loc) · 6.7 KB
/
__init__.py
File metadata and controls
179 lines (152 loc) · 6.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import os
import torch
from loguru import logger
from pathlib import Path
from typing import Optional
from transformers import AutoConfig
from transformers.models.bert import BertConfig
from text_embeddings_server.models.model import Model
from text_embeddings_server.models.masked_model import MaskedLanguageModel
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.classification_model import ClassificationModel
from text_embeddings_server.utils.device import get_device, use_ipex, is_neuron
__all__ = ["Model"]

# Boolean env flags: accept the common truthy spellings "true"/"1" (case-insensitive).
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [
    "true",
    "1",
]

# Flash Attention models - only available when flash_attn is installed.
# The names default to None so later dispatch code can test for availability.
FLASH_ATTENTION = True
FlashBert = None
FlashJinaBert = None
FlashMistral = None
FlashQwen3 = None
try:
    from text_embeddings_server.models.flash_bert import FlashBert
    from text_embeddings_server.models.jinaBert_model import FlashJinaBert
    from text_embeddings_server.models.flash_mistral import FlashMistral
    from text_embeddings_server.models.flash_qwen3 import FlashQwen3

    # Disable gradients globally for inference.
    # NOTE(review): this only runs when the flash imports succeed — presumably
    # intentional, but gradients stay enabled on the fallback path; confirm.
    torch.set_grad_enabled(False)
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    # BUG FIX: __all__ entries must be strings naming exported symbols,
    # not the objects themselves; the original appended the class object.
    __all__.append("FlashBert")

# Neuron models - only import when on Neuron device to avoid unnecessary dependencies
NeuronSentenceTransformersModel = None
NeuronClassificationModel = None
NeuronMaskedLMModel = None
create_neuron_model = None
if is_neuron():
    try:
        from text_embeddings_server.models.neuron_models import (
            NeuronSentenceTransformersModel,
            NeuronClassificationModel,
            NeuronMaskedLMModel,
            create_neuron_model,
        )
    except ImportError as e:
        logger.warning(f"Could not import Neuron models: {e}")
def wrap_model_if_hpu(model_handle, device):
    """Wrap ``model_handle.model`` in an HPU graph when ``device`` is an HPU.

    On any other device the handle is returned untouched.
    """
    if device.type != "hpu":
        return model_handle
    # Imported lazily: habana_frameworks only exists on Gaudi installs.
    from habana_frameworks.torch.hpu import wrap_in_hpu_graph

    model_handle.model = wrap_in_hpu_graph(
        model_handle.model, disable_tensor_cache=DISABLE_TENSOR_CACHE
    )
    return model_handle
def create_model(model_class, model_path, device, datatype, pool="cls"):
    """Instantiate ``model_class`` and apply HPU-graph wrapping when applicable.

    All backend model classes share the positional signature
    ``(model_path, device, datatype, pool)`` plus the ``trust_remote`` keyword.
    """
    handle = model_class(
        model_path,
        device,
        datatype,
        pool,
        trust_remote=TRUST_REMOTE_CODE,
    )
    return wrap_model_if_hpu(handle, device)
def get_model(model_path: Path, dtype: Optional[str], pool: str):
    """Select and instantiate the best available backend for a checkpoint.

    Dispatch order: Neuron (optimum-neuron) when on a Neuron device, then a
    hard-coded Jina-BERT remote-code model, then flash-attention variants
    (bert / mistral / qwen3) when the device, dtype and installed extras
    allow it, finally the generic classification / masked-LM / default paths.

    Args:
        model_path: local directory holding the model weights and config.
        dtype: one of "float32", "float16", "bfloat16".
        pool: pooling strategy (e.g. "cls", "splade"); forwarded to the model.

    Returns:
        A ready model handle (wrapped in an HPU graph when on HPU).

    Raises:
        RuntimeError: if ``dtype`` is not one of the supported values.
    """
    dtype_to_torch = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    if dtype not in dtype_to_torch:
        raise RuntimeError(f"Unknown dtype {dtype}")
    datatype = dtype_to_torch[dtype]

    device = get_device()
    logger.info(f"backend device: {device}")

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

    # Neuron cases - use optimum-neuron for all supported model types
    if is_neuron():
        logger.info(
            f"Neuron device detected, using optimum-neuron backend for model type: {config.model_type}"
        )
        try:
            return create_neuron_model(
                model_path=model_path,
                device=device,
                dtype=datatype,
                pool=pool,
                trust_remote=TRUST_REMOTE_CODE,
                config=config,
            )
        except Exception as e:
            logger.warning(f"Failed to load model with optimum-neuron: {e}")
            logger.warning("Falling back to default model loading path")
            # Fall through to default model loading

    # Specific offline modeling for "jinaai/jina-embeddings-v2-base-code",
    # which uses "auto_map" to reference code in another repository.
    if (
        FlashJinaBert is not None
        and hasattr(config, "auto_map")
        and isinstance(config.auto_map, dict)
        and "AutoModel" in config.auto_map
        and config.auto_map["AutoModel"]
        == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
    ):
        return create_model(FlashJinaBert, model_path, device, datatype)

    if config.model_type == "bert":
        config: BertConfig
        # BUG FIX: the original condition was
        #   use_ipex() or device.type in [...] and <flash requirements>
        # Since `and` binds tighter than `or`, a truthy use_ipex() bypassed the
        # position-embedding / dtype / FLASH_ATTENTION checks and could reach
        # FlashBert even when it failed to import (FlashBert is None).
        # Parenthesize the device alternatives so the flash requirements
        # always apply.
        if (
            (use_ipex() or device.type in ["cuda", "hpu"])
            and config.position_embedding_type == "absolute"
            and datatype in [torch.float16, torch.bfloat16]
            and FLASH_ATTENTION
        ):
            if pool != "cls":
                if config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
                    return create_model(
                        MaskedLanguageModel, model_path, device, datatype, pool
                    )
                return create_model(DefaultModel, model_path, device, datatype, pool)
            try:
                return create_model(FlashBert, model_path, device, datatype)
            except FileNotFoundError:
                logger.info(
                    "Do not have safetensors file for this model, use default transformers model path instead"
                )
                return create_model(DefaultModel, model_path, device, datatype, pool)
        if config.architectures[0].endswith("Classification"):
            return create_model(ClassificationModel, model_path, device, datatype)
        elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
            return create_model(MaskedLanguageModel, model_path, device, datatype)
        else:
            return create_model(DefaultModel, model_path, device, datatype, pool)

    # Flash mistral/qwen3 only on HPU and only when the optional import succeeded;
    # fall back to the default transformers path when safetensors are missing.
    if config.model_type == "mistral" and device.type == "hpu" and FlashMistral is not None:
        try:
            return create_model(FlashMistral, model_path, device, datatype, pool)
        except FileNotFoundError:
            return create_model(DefaultModel, model_path, device, datatype, pool)

    if config.model_type == "qwen3" and device.type == "hpu" and FlashQwen3 is not None:
        try:
            return create_model(FlashQwen3, model_path, device, datatype, pool)
        except FileNotFoundError:
            return create_model(DefaultModel, model_path, device, datatype, pool)

    # Default case
    if config.architectures[0].endswith("Classification"):
        return create_model(ClassificationModel, model_path, device, datatype)
    elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
        return create_model(MaskedLanguageModel, model_path, device, datatype)
    else:
        return create_model(DefaultModel, model_path, device, datatype, pool)