Skip to content

Commit 8ba9535

Browse files
Add trust remote code (#185)
* lint * fix trust remote code * Update optimizer_config.schema.json * update fix trust remote code * fix test callback --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent b6a0c85 commit 8ba9535

File tree

8 files changed

+36
-14
lines changed

8 files changed

+36
-14
lines changed

autointent/_embedder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def __init__(self, embedder_config: EmbedderConfig) -> None:
7373
self.config = embedder_config
7474

7575
self.embedding_model = SentenceTransformer(
76-
self.config.model_name, device=self.config.device, prompts=embedder_config.get_prompt_config()
76+
self.config.model_name,
77+
device=self.config.device,
78+
prompts=embedder_config.get_prompt_config(),
79+
trust_remote_code=self.config.trust_remote_code,
7780
)
7881

7982
self._logger = logging.getLogger(__name__)

autointent/_ranker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __init__(
111111
self.config = CrossEncoderConfig.from_search_config(cross_encoder_config)
112112
self.cross_encoder = st.CrossEncoder(
113113
self.config.model_name,
114-
trust_remote_code=True,
114+
trust_remote_code=self.config.trust_remote_code,
115115
device=self.config.device,
116116
max_length=self.config.tokenizer_config.max_length, # type: ignore[arg-type]
117117
)

autointent/configs/_transformers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class HFModelConfig(BaseModel):
1919
batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
2020
device: str | None = Field(None, description="Torch notation for CPU or CUDA.")
2121
tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig)
22+
trust_remote_code: bool = Field(False, description="Whether to trust the remote code when loading the model.")
2223

2324
@classmethod
2425
def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:

autointent/modules/scoring/_bert.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def fit(
8989

9090
self._model = AutoModelForSequenceClassification.from_pretrained(
9191
model_name,
92+
trust_remote_code=self.classification_model_config.trust_remote_code,
9293
num_labels=self._n_classes,
9394
label2id=label2id,
9495
id2label=id2label,

docs/optimizer_config.schema.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@
3232
"tokenizer_config": {
3333
"$ref": "#/$defs/TokenizerConfig"
3434
},
35+
"trust_remote_code": {
36+
"default": false,
37+
"description": "Whether to trust the remote code when loading the model.",
38+
"title": "Trust Remote Code",
39+
"type": "boolean"
40+
},
3541
"train_head": {
3642
"default": false,
3743
"description": "Whether to train the head of the model. If False, LogReg will be trained.",
@@ -122,6 +128,12 @@
122128
"tokenizer_config": {
123129
"$ref": "#/$defs/TokenizerConfig"
124130
},
131+
"trust_remote_code": {
132+
"default": false,
133+
"description": "Whether to trust the remote code when loading the model.",
134+
"title": "Trust Remote Code",
135+
"type": "boolean"
136+
},
125137
"default_prompt": {
126138
"anyOf": [
127139
{
@@ -370,6 +382,7 @@
370382
"padding": true,
371383
"truncation": true
372384
},
385+
"trust_remote_code": false,
373386
"default_prompt": null,
374387
"classifier_prompt": null,
375388
"cluster_prompt": null,
@@ -390,6 +403,7 @@
390403
"padding": true,
391404
"truncation": true
392405
},
406+
"trust_remote_code": false,
393407
"train_head": false
394408
}
395409
},

tests/callback/test_callback.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def test_pipeline_callbacks(dataset):
146146
"query_prompt": None,
147147
"sts_prompt": None,
148148
"use_cache": False,
149+
"trust_remote_code": False,
149150
},
150151
"k": 1,
151152
"weights": "uniform",
@@ -180,6 +181,7 @@ def test_pipeline_callbacks(dataset):
180181
"query_prompt": None,
181182
"sts_prompt": None,
182183
"use_cache": False,
184+
"trust_remote_code": False,
183185
},
184186
"k": 1,
185187
"weights": "distance",
@@ -214,6 +216,7 @@ def test_pipeline_callbacks(dataset):
214216
"query_prompt": None,
215217
"sts_prompt": None,
216218
"use_cache": False,
219+
"trust_remote_code": False,
217220
},
218221
},
219222
"module_name": "linear",

tests/configs/test_combined_config.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ def test_invalid_optimizer_config_missing_field():
7575
def test_invalid_optimizer_config_wrong_type():
7676
"""Test that an invalid field type raises ValidationError."""
7777
invalid_config = {
78-
"node_type": "scoring",
79-
"target_metric": "scoring_roc_auc",
80-
"search_space": [
81-
{
82-
"module_name": "dnnc",
83-
"cross_encoder_name": "cross-encoder/ms-marco-MiniLM-L-6-v2", # Should be a list
84-
"k": "wrong_type", # Should be a list of integers
85-
"train_head": "true", # Should be a boolean, not a string
86-
}
87-
],
88-
}
78+
"node_type": "scoring",
79+
"target_metric": "scoring_roc_auc",
80+
"search_space": [
81+
{
82+
"module_name": "dnnc",
83+
"cross_encoder_name": "cross-encoder/ms-marco-MiniLM-L-6-v2", # Should be a list
84+
"k": "wrong_type", # Should be a list of integers
85+
"train_head": "true", # Should be a boolean, not a string
86+
}
87+
],
88+
}
8989

9090
with pytest.raises(TypeError):
9191
NodeOptimizer(**invalid_config)

tests/modules/scoring/test_bert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_bert_scorer_dump_load(dataset):
5151

5252
finally:
5353
# Clean up
54-
shutil.rmtree(temp_dir_path, ignore_errors=True) # workaround for windows permission error
54+
shutil.rmtree(temp_dir_path, ignore_errors=True) # workaround for windows permission error
5555

5656

5757
def test_bert_prediction(dataset):

0 commit comments

Comments
 (0)