Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion autointent/_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def __init__(self, embedder_config: EmbedderConfig) -> None:
self.config = embedder_config

self.embedding_model = SentenceTransformer(
self.config.model_name, device=self.config.device, prompts=embedder_config.get_prompt_config()
self.config.model_name,
device=self.config.device,
prompts=embedder_config.get_prompt_config(),
trust_remote_code=self.config.trust_remote_code,
)

self._logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion autointent/_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(
self.config = CrossEncoderConfig.from_search_config(cross_encoder_config)
self.cross_encoder = st.CrossEncoder(
self.config.model_name,
trust_remote_code=True,
trust_remote_code=self.config.trust_remote_code,
device=self.config.device,
max_length=self.config.tokenizer_config.max_length, # type: ignore[arg-type]
)
Expand Down
1 change: 1 addition & 0 deletions autointent/configs/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class HFModelConfig(BaseModel):
batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
device: str | None = Field(None, description="Torch notation for CPU or CUDA.")
tokenizer_config: TokenizerConfig = Field(default_factory=TokenizerConfig)
trust_remote_code: bool = Field(False, description="Whether to trust the remote code when loading the model.")

@classmethod
def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> Self:
Expand Down
1 change: 1 addition & 0 deletions autointent/modules/scoring/_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def fit(

self._model = AutoModelForSequenceClassification.from_pretrained(
model_name,
trust_remote_code=self.classification_model_config.trust_remote_code,
num_labels=self._n_classes,
label2id=label2id,
id2label=id2label,
Expand Down
14 changes: 14 additions & 0 deletions docs/optimizer_config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@
"tokenizer_config": {
"$ref": "#/$defs/TokenizerConfig"
},
"trust_remote_code": {
"default": false,
"description": "Whether to trust the remote code when loading the model.",
"title": "Trust Remote Code",
"type": "boolean"
},
"train_head": {
"default": false,
"description": "Whether to train the head of the model. If False, LogReg will be trained.",
Expand Down Expand Up @@ -122,6 +128,12 @@
"tokenizer_config": {
"$ref": "#/$defs/TokenizerConfig"
},
"trust_remote_code": {
"default": false,
"description": "Whether to trust the remote code when loading the model.",
"title": "Trust Remote Code",
"type": "boolean"
},
"default_prompt": {
"anyOf": [
{
Expand Down Expand Up @@ -370,6 +382,7 @@
"padding": true,
"truncation": true
},
"trust_remote_code": false,
"default_prompt": null,
"classifier_prompt": null,
"cluster_prompt": null,
Expand All @@ -390,6 +403,7 @@
"padding": true,
"truncation": true
},
"trust_remote_code": false,
"train_head": false
}
},
Expand Down
3 changes: 3 additions & 0 deletions tests/callback/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def test_pipeline_callbacks(dataset):
"query_prompt": None,
"sts_prompt": None,
"use_cache": False,
"trust_remote_code": False,
},
"k": 1,
"weights": "uniform",
Expand Down Expand Up @@ -180,6 +181,7 @@ def test_pipeline_callbacks(dataset):
"query_prompt": None,
"sts_prompt": None,
"use_cache": False,
"trust_remote_code": False,
},
"k": 1,
"weights": "distance",
Expand Down Expand Up @@ -214,6 +216,7 @@ def test_pipeline_callbacks(dataset):
"query_prompt": None,
"sts_prompt": None,
"use_cache": False,
"trust_remote_code": False,
},
},
"module_name": "linear",
Expand Down
22 changes: 11 additions & 11 deletions tests/configs/test_combined_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,17 @@ def test_invalid_optimizer_config_missing_field():
def test_invalid_optimizer_config_wrong_type():
"""Test that an invalid field type raises ValidationError."""
invalid_config = {
"node_type": "scoring",
"target_metric": "scoring_roc_auc",
"search_space": [
{
"module_name": "dnnc",
"cross_encoder_name": "cross-encoder/ms-marco-MiniLM-L-6-v2", # Should be a list
"k": "wrong_type", # Should be a list of integers
"train_head": "true", # Should be a boolean, not a string
}
],
}
"node_type": "scoring",
"target_metric": "scoring_roc_auc",
"search_space": [
{
"module_name": "dnnc",
"cross_encoder_name": "cross-encoder/ms-marco-MiniLM-L-6-v2", # Should be a list
"k": "wrong_type", # Should be a list of integers
"train_head": "true", # Should be a boolean, not a string
}
],
}

with pytest.raises(TypeError):
NodeOptimizer(**invalid_config)
2 changes: 1 addition & 1 deletion tests/modules/scoring/test_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_bert_scorer_dump_load(dataset):

finally:
# Clean up
shutil.rmtree(temp_dir_path, ignore_errors=True) # workaround for windows permission error
shutil.rmtree(temp_dir_path, ignore_errors=True) # workaround for windows permission error


def test_bert_prediction(dataset):
Expand Down
Loading