diff --git a/plexe/internal/common/utils/pydantic_utils.py b/plexe/internal/common/utils/pydantic_utils.py index f542cdf..b8fd1c1 100644 --- a/plexe/internal/common/utils/pydantic_utils.py +++ b/plexe/internal/common/utils/pydantic_utils.py @@ -67,19 +67,20 @@ def map_to_basemodel(name: str, schema: dict | Type[BaseModel]) -> Type[BaseMode # Handle both Dict[str, type] and Dict[str, str] formats annotated_schema = {} + type_mapping = { + "int": int, + "float": float, + "str": str, + "bool": bool, + "List[int]": List[int], + "List[float]": List[float], + "List[str]": List[str], + "List[bool]": List[bool], + } + for k, v in schema.items(): # If v is a string like "int", convert it to the actual type if isinstance(v, str): - type_mapping = { - "int": int, - "float": float, - "str": str, - "bool": bool, - "List[int]": List[int], - "List[float]": List[float], - "List[str]": List[str], - "List[bool]": List[bool], - } if v in type_mapping: annotated_schema[k] = (type_mapping[v], ...) else: diff --git a/plexe/models.py b/plexe/models.py index dfab651..eb9faaa 100644 --- a/plexe/models.py +++ b/plexe/models.py @@ -38,7 +38,7 @@ import uuid import warnings from datetime import datetime -from typing import Dict, List, Type, Any +from typing import Dict, List, Type, Any, Optional from deprecated import deprecated import pandas as pd @@ -93,8 +93,8 @@ class Model: def __init__( self, intent: str, - input_schema: Type[BaseModel] | Dict[str, type] = None, - output_schema: Type[BaseModel] | Dict[str, type] = None, + input_schema: Type[BaseModel] | Dict[str, type] = {}, + output_schema: Type[BaseModel] | Dict[str, type] = {}, distributed: bool = False, ): """ @@ -110,8 +110,10 @@ def __init__( # The model's identity is defined by these fields self.intent: str = intent - self.input_schema: Type[BaseModel] = map_to_basemodel("in", input_schema) if input_schema else None - self.output_schema: Type[BaseModel] = map_to_basemodel("out", output_schema) if output_schema else None + self.input_schema: Optional[Type[BaseModel]] = map_to_basemodel("in", input_schema) if input_schema else None + self.output_schema: Optional[Type[BaseModel]] = ( + map_to_basemodel("out", output_schema) if output_schema else None + ) self.training_data: Dict[str, Dataset] = dict() self.distributed: bool = distributed @@ -142,10 +144,10 @@ def build( self, datasets: List[pd.DataFrame | DatasetGenerator], provider: str | ProviderConfig = "openai/gpt-4o-mini", - timeout: int = None, - max_iterations: int = None, + timeout: int | None = None, + max_iterations: int | None = None, run_timeout: int = 1800, - callbacks: List[Callback] = None, + callbacks: List[Callback] | None = None, verbose: bool = False, # resume: bool = False, enable_checkpointing: bool = False,