Skip to content

Commit 4ac772e

Browse files
authored
add TRUST_REMOTE_CODE param to python backend. (#485)
Signed-off-by: kaixuanliu <[email protected]>
1 parent 9561fc9 commit 4ac772e

File tree

3 files changed

+63
-12
lines changed

3 files changed

+63
-12
lines changed

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 42 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import os
12
import torch
23

34
from loguru import logger
@@ -13,6 +14,7 @@
1314

1415
__all__ = ["Model"]
1516

17+
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
1618
# Disable gradients
1719
torch.set_grad_enabled(False)
1820

@@ -40,7 +42,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
4042
device = get_device()
4143
logger.info(f"backend device: {device}")
4244

43-
config = AutoConfig.from_pretrained(model_path)
45+
config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
4446
if config.model_type == "bert":
4547
config: BertConfig
4648
if (
@@ -51,12 +53,22 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
5153
and FLASH_ATTENTION
5254
):
5355
if pool != "cls":
54-
return DefaultModel(model_path, device, datatype, pool)
56+
return DefaultModel(
57+
model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
58+
)
5559
return FlashBert(model_path, device, datatype)
5660
if config.architectures[0].endswith("Classification"):
57-
return ClassificationModel(model_path, device, datatype)
61+
return ClassificationModel(
62+
model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
63+
)
5864
else:
59-
return DefaultModel(model_path, device, datatype, pool)
65+
return DefaultModel(
66+
model_path,
67+
device,
68+
datatype,
69+
pool,
70+
trust_remote=TRUST_REMOTE_CODE,
71+
)
6072
else:
6173
if device.type == "hpu":
6274
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -66,13 +78,35 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
6678

6779
adapt_transformers_to_gaudi()
6880
if config.architectures[0].endswith("Classification"):
69-
model_handle = ClassificationModel(model_path, device, datatype)
81+
model_handle = ClassificationModel(
82+
model_path,
83+
device,
84+
datatype,
85+
trust_remote=TRUST_REMOTE_CODE,
86+
)
7087
else:
71-
model_handle = DefaultModel(model_path, device, datatype, pool)
88+
model_handle = DefaultModel(
89+
model_path,
90+
device,
91+
datatype,
92+
pool,
93+
trust_remote=TRUST_REMOTE_CODE,
94+
)
7295
model_handle.model = wrap_in_hpu_graph(model_handle.model)
7396
return model_handle
7497
elif use_ipex():
7598
if config.architectures[0].endswith("Classification"):
76-
return ClassificationModel(model_path, device, datatype)
99+
return ClassificationModel(
100+
model_path,
101+
device,
102+
datatype,
103+
trust_remote=TRUST_REMOTE_CODE,
104+
)
77105
else:
78-
return DefaultModel(model_path, device, datatype, pool)
106+
return DefaultModel(
107+
model_path,
108+
device,
109+
datatype,
110+
pool,
111+
trust_remote=TRUST_REMOTE_CODE,
112+
)

backends/python/server/text_embeddings_server/models/classification_model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,16 @@
1313

1414

1515
class ClassificationModel(Model):
16-
def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
17-
model = AutoModelForSequenceClassification.from_pretrained(model_path)
16+
def __init__(
17+
self,
18+
model_path: Path,
19+
device: torch.device,
20+
dtype: torch.dtype,
21+
trust_remote: bool = False,
22+
):
23+
model = AutoModelForSequenceClassification.from_pretrained(
24+
model_path, trust_remote_code=trust_remote
25+
)
1826
model = model.to(dtype).to(device)
1927

2028
self.hidden_size = model.config.hidden_size

backends/python/server/text_embeddings_server/models/default_model.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,18 @@
1515

1616
class DefaultModel(Model):
1717
def __init__(
18-
self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
18+
self,
19+
model_path: Path,
20+
device: torch.device,
21+
dtype: torch.dtype,
22+
pool: str,
23+
trust_remote: bool = False,
1924
):
20-
model = AutoModel.from_pretrained(model_path).to(dtype).to(device)
25+
model = (
26+
AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote)
27+
.to(dtype)
28+
.to(device)
29+
)
2130
self.hidden_size = model.config.hidden_size
2231
self.pooling = Pooling(self.hidden_size, pooling_mode=pool)
2332

0 commit comments

Comments
 (0)