diff --git a/autointent/_embedder.py b/autointent/_embedder.py
index 4de8f46fa..c63e8a31d 100644
--- a/autointent/_embedder.py
+++ b/autointent/_embedder.py
@@ -73,7 +73,10 @@
     def __init__(self, embedder_config: EmbedderConfig) -> None:
         self.config = embedder_config
         self.embedding_model = SentenceTransformer(
-            self.config.model_name, device=self.config.device, prompts=embedder_config.get_prompt_config()
+            self.config.model_name,
+            device=self.config.device,
+            prompts=embedder_config.get_prompt_config(),
+            trust_remote_code=self.config.trust_remote_code,
         )
         self._logger = logging.getLogger(__name__)

diff --git a/autointent/_ranker.py b/autointent/_ranker.py
index 43774d5dd..19b10e634 100644
--- a/autointent/_ranker.py
+++ b/autointent/_ranker.py
@@ -111,7 +111,7 @@ def __init__(
         self.config = CrossEncoderConfig.from_search_config(cross_encoder_config)
         self.cross_encoder = st.CrossEncoder(
             self.config.model_name,
-            trust_remote_code=True,
+            trust_remote_code=self.config.trust_remote_code,
             device=self.config.device,
             max_length=self.config.tokenizer_config.max_length,  # type: ignore[arg-type]
         )
diff --git a/autointent/configs/_transformers.py b/autointent/configs/_transformers.py
index bfc88839d..77948737d 100644
--- a/autointent/configs/_transformers.py
+++ b/autointent/configs/_transformers.py
@@ -19,6 +19,7 @@ class HFModelConfig(BaseModel):
     batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
     device: str | None = Field(None, description="Torch notation for CPU or CUDA.")
     tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig)
+    trust_remote_code: bool = Field(False, description="Whether to trust the remote code when loading the model.")

     @classmethod
     def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
diff --git a/autointent/modules/scoring/_bert.py b/autointent/modules/scoring/_bert.py
index d292fea1c..49965d653 100644
--- a/autointent/modules/scoring/_bert.py
+++ b/autointent/modules/scoring/_bert.py
@@ -89,6 +89,7 @@ def fit(

         self._model = AutoModelForSequenceClassification.from_pretrained(
             model_name,
+            trust_remote_code=self.classification_model_config.trust_remote_code,
             num_labels=self._n_classes,
             label2id=label2id,
             id2label=id2label,
diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json
index 192e86c5c..cb05e889b 100644
--- a/docs/optimizer_config.schema.json
+++ b/docs/optimizer_config.schema.json
@@ -32,6 +32,12 @@
       "tokenizer_config": {
         "$ref": "#/$defs/TokenizerConfig"
       },
+      "trust_remote_code": {
+        "default": false,
+        "description": "Whether to trust the remote code when loading the model.",
+        "title": "Trust Remote Code",
+        "type": "boolean"
+      },
       "train_head": {
         "default": false,
         "description": "Whether to train the head of the model. If False, LogReg will be trained.",
@@ -122,6 +128,12 @@
       "tokenizer_config": {
         "$ref": "#/$defs/TokenizerConfig"
       },
+      "trust_remote_code": {
+        "default": false,
+        "description": "Whether to trust the remote code when loading the model.",
+        "title": "Trust Remote Code",
+        "type": "boolean"
+      },
       "default_prompt": {
         "anyOf": [
           {
@@ -370,6 +382,7 @@
           "padding": true,
           "truncation": true
         },
+        "trust_remote_code": false,
         "default_prompt": null,
         "classifier_prompt": null,
         "cluster_prompt": null,
@@ -390,6 +403,7 @@
           "padding": true,
           "truncation": true
         },
+        "trust_remote_code": false,
         "train_head": false
       }
     },
diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py
index 30d2b0d3c..e121aa29a 100644
--- a/tests/callback/test_callback.py
+++ b/tests/callback/test_callback.py
@@ -146,6 +146,7 @@ def test_pipeline_callbacks(dataset):
                 "query_prompt": None,
                 "sts_prompt": None,
                 "use_cache": False,
+                "trust_remote_code": False,
             },
             "k": 1,
             "weights": "uniform",
@@ -180,6 +181,7 @@
                 "query_prompt": None,
                 "sts_prompt": None,
                 "use_cache": False,
+                "trust_remote_code": False,
             },
             "k": 1,
             "weights": "distance",
@@ -214,6 +216,7 @@
                 "query_prompt": None,
                 "sts_prompt": None,
                 "use_cache": False,
+                "trust_remote_code": False,
             },
         },
         "module_name": "linear",
diff --git a/tests/configs/test_combined_config.py b/tests/configs/test_combined_config.py
index 81312c7f0..345cf082d 100644
--- a/tests/configs/test_combined_config.py
+++ b/tests/configs/test_combined_config.py
@@ -75,17 +75,17 @@ def test_invalid_optimizer_config_missing_field():
 def test_invalid_optimizer_config_wrong_type():
     """Test that an invalid field type raises ValidationError."""
     invalid_config = {
-            "node_type": "scoring",
-            "target_metric": "scoring_roc_auc",
-            "search_space": [
-                {
-                    "module_name": "dnnc",
-                    "cross_encoder_name": "cross-encoder/ms-marco-MiniLM-L-6-v2",  # Should be a list
-                    "k": "wrong_type",  # Should be a list of integers
-                    "train_head": "true",  # Should be a boolean, not a string
-                }
-            ],
-        }
+        "node_type": "scoring",
+        "target_metric": "scoring_roc_auc",
+        "search_space": [
+            {
+                "module_name": "dnnc",
+                "cross_encoder_name": "cross-encoder/ms-marco-MiniLM-L-6-v2",  # Should be a list
+                "k": "wrong_type",  # Should be a list of integers
+                "train_head": "true",  # Should be a boolean, not a string
+            }
+        ],
+    }

     with pytest.raises(TypeError):
         NodeOptimizer(**invalid_config)
diff --git a/tests/modules/scoring/test_bert.py b/tests/modules/scoring/test_bert.py
index 4512cd1fd..a2b7a3c5d 100644
--- a/tests/modules/scoring/test_bert.py
+++ b/tests/modules/scoring/test_bert.py
@@ -51,7 +51,7 @@
     finally:
         # Clean up
-        shutil.rmtree(temp_dir_path, ignore_errors=True) # workaround for windows permission error
+        shutil.rmtree(temp_dir_path, ignore_errors=True)  # workaround for windows permission error


 def test_bert_prediction(dataset):
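
A minimal usage sketch of the new flag (the import paths, the `Embedder` class name, and the model name are illustrative assumptions based on the file layout above, not confirmed by this diff):

    # Sketch: opting in to remote code for an embedding model whose Hub
    # repository ships custom modeling code. Import paths are assumptions.
    from autointent import Embedder
    from autointent.configs import EmbedderConfig

    config = EmbedderConfig(
        model_name="jinaai/jina-embeddings-v3",  # illustrative model that needs remote code
        trust_remote_code=True,  # new field; defaults to False
    )
    embedder = Embedder(config)  # SentenceTransformer receives the flag

Note the behavior change in `autointent/_ranker.py`: the cross-encoder previously hardcoded `trust_remote_code=True`, whereas it now defaults to `False` and must be enabled explicitly. The same `HFModelConfig` field is read by `st.CrossEncoder` in `_ranker.py` and by `AutoModelForSequenceClassification.from_pretrained` in `_bert.py`, so each loader respects its own config rather than a hardcoded value.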