1
+ import os
1
2
import torch
2
3
3
4
from loguru import logger
13
14
14
15
__all__ = ["Model" ]
15
16
17
+ TRUST_REMOTE_CODE = os .getenv ("TRUST_REMOTE_CODE" , "false" ).lower () in ["true" , "1" ]
16
18
# Disable gradients
17
19
torch .set_grad_enabled (False )
18
20
@@ -40,7 +42,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
40
42
device = get_device ()
41
43
logger .info (f"backend device: { device } " )
42
44
43
- config = AutoConfig .from_pretrained (model_path )
45
+ config = AutoConfig .from_pretrained (model_path , trust_remote_code = TRUST_REMOTE_CODE )
44
46
if config .model_type == "bert" :
45
47
config : BertConfig
46
48
if (
@@ -51,12 +53,22 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
51
53
and FLASH_ATTENTION
52
54
):
53
55
if pool != "cls" :
54
- return DefaultModel (model_path , device , datatype , pool )
56
+ return DefaultModel (
57
+ model_path , device , datatype , pool , trust_remote = TRUST_REMOTE_CODE
58
+ )
55
59
return FlashBert (model_path , device , datatype )
56
60
if config .architectures [0 ].endswith ("Classification" ):
57
- return ClassificationModel (model_path , device , datatype )
61
+ return ClassificationModel (
62
+ model_path , device , datatype , trust_remote = TRUST_REMOTE_CODE
63
+ )
58
64
else :
59
- return DefaultModel (model_path , device , datatype , pool )
65
+ return DefaultModel (
66
+ model_path ,
67
+ device ,
68
+ datatype ,
69
+ pool ,
70
+ trust_remote = TRUST_REMOTE_CODE ,
71
+ )
60
72
else :
61
73
if device .type == "hpu" :
62
74
from habana_frameworks .torch .hpu import wrap_in_hpu_graph
@@ -66,13 +78,35 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
66
78
67
79
adapt_transformers_to_gaudi ()
68
80
if config .architectures [0 ].endswith ("Classification" ):
69
- model_handle = ClassificationModel (model_path , device , datatype )
81
+ model_handle = ClassificationModel (
82
+ model_path ,
83
+ device ,
84
+ datatype ,
85
+ trust_remote = TRUST_REMOTE_CODE ,
86
+ )
70
87
else :
71
- model_handle = DefaultModel (model_path , device , datatype , pool )
88
+ model_handle = DefaultModel (
89
+ model_path ,
90
+ device ,
91
+ datatype ,
92
+ pool ,
93
+ trust_remote = TRUST_REMOTE_CODE ,
94
+ )
72
95
model_handle .model = wrap_in_hpu_graph (model_handle .model )
73
96
return model_handle
74
97
elif use_ipex ():
75
98
if config .architectures [0 ].endswith ("Classification" ):
76
- return ClassificationModel (model_path , device , datatype )
99
+ return ClassificationModel (
100
+ model_path ,
101
+ device ,
102
+ datatype ,
103
+ trust_remote = TRUST_REMOTE_CODE ,
104
+ )
77
105
else :
78
- return DefaultModel (model_path , device , datatype , pool )
106
+ return DefaultModel (
107
+ model_path ,
108
+ device ,
109
+ datatype ,
110
+ pool ,
111
+ trust_remote = TRUST_REMOTE_CODE ,
112
+ )
0 commit comments