Skip to content

Commit 7c20546

Browse files
committed
try to fix dynamic schema issues
1 parent cb24d83 commit 7c20546

File tree

5 files changed

+21
-7
lines changed

5 files changed

+21
-7
lines changed

src/autointent/_optimization_config.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
from typing import Any
22

3-
from pydantic import BaseModel, PositiveInt
3+
from pydantic import BaseModel, Field, PositiveInt, field_validator
44

5-
from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, HPOConfig, LoggingConfig
5+
from .configs import (
6+
CrossEncoderConfig,
7+
DataConfig,
8+
EmbedderConfig,
9+
HFModelConfig,
10+
HPOConfig,
11+
LoggingConfig,
12+
initialize_embedder_config,
13+
)
614

715

816
class OptimizationConfig(BaseModel):
@@ -20,7 +28,13 @@ class OptimizationConfig(BaseModel):
2028
logging_config: LoggingConfig = LoggingConfig()
2129
"""See tutorial on logging configuration."""
2230

23-
embedder_config: EmbedderConfig = EmbedderConfig()
31+
embedder_config: EmbedderConfig = Field(default_factory=lambda: initialize_embedder_config(None))
32+
33+
@field_validator("embedder_config", mode="before")
34+
@classmethod
35+
def validate_embedder_config(cls, v: Any) -> EmbedderConfig: # noqa: ANN401
36+
"""Validate and convert embedder config to proper type."""
37+
return initialize_embedder_config(v)
2438

2539
cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()
2640

src/autointent/modules/embedding/_logreg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(
6767
def from_context(
6868
cls,
6969
context: Context,
70-
embedder_config: EmbedderConfig | str | None = None,
70+
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
7171
ft_config: EmbedderFineTuningConfig | dict[str, Any] | None = None,
7272
cv: PositiveInt = 3,
7373
) -> "LogregAimedEmbedding":

src/autointent/modules/embedding/_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
def from_context(
7272
cls,
7373
context: Context,
74-
embedder_config: EmbedderConfig | str | None = None,
74+
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
7575
k: PositiveInt = 10,
7676
ft_config: EmbedderFineTuningConfig | dict[str, Any] | None = None,
7777
) -> "RetrievalAimedEmbedding":

src/autointent/modules/scoring/_description/bi_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def from_context(
7373
cls,
7474
context: Context,
7575
temperature: PositiveFloat = 1.0,
76-
embedder_config: EmbedderConfig | str | None = None,
76+
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
7777
) -> "BiEncoderDescriptionScorer":
7878
"""Create a BiEncoderDescriptionScorer instance using a Context object.
7979

src/autointent/modules/scoring/_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def from_context(
7373
cls,
7474
context: Context,
7575
cv: PositiveInt = 3,
76-
embedder_config: EmbedderConfig | str | None = None,
76+
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
7777
) -> "LinearScorer":
7878
"""Create a LinearScorer instance using a Context object.
7979

0 commit comments

Comments
 (0)