Skip to content

Commit 5566e76

Browse files
bahdan111 (Bahdan Kapionkin) and co-authors authored
feat: add GPT-5 model family support #92 (#140)
Co-authored-by: Bahdan Kapionkin <bahdan_kapionkin@epam.com>
1 parent 378a59c commit 5566e76

File tree

6 files changed

+89
-23
lines changed

6 files changed

+89
-23
lines changed

statgpt/common/config/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .llm_models import EmbeddingModelsEnum, LLMModelsEnum
1+
from .llm_models import EmbeddingModelsEnum, LLMModelsEnum, ReasoningEffortEnum, VerbosityEnum
22
from .logging import LoggingConfig, logger, multiline_logger
33
from .versions import Versions
44

statgpt/common/config/llm_models.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,26 @@ class EmbeddingModelsEnum(StrEnum):
66
TEXT_EMBEDDING_3_LARGE = "text-embedding-3-large"
77

88

9+
class ReasoningEffortEnum(StrEnum):
10+
"""Reasoning effort levels for GPT-5 models."""
11+
12+
NONE = "none"
13+
"""No reasoning mode - standard inference."""
14+
MINIMAL = "minimal"
15+
LOW = "low"
16+
MEDIUM = "medium"
17+
HIGH = "high"
18+
XHIGH = "xhigh"
19+
20+
21+
class VerbosityEnum(StrEnum):
22+
"""Output verbosity levels for GPT-5 models."""
23+
24+
LOW = "low"
25+
MEDIUM = "medium"
26+
HIGH = "high"
27+
28+
929
class LLMModelsEnum(StrEnum):
1030
# Gemini models
1131
GEMINI_2_0_FLASH_LITE_001 = "gemini-2.0-flash-lite-001"
@@ -30,6 +50,11 @@ class LLMModelsEnum(StrEnum):
3050
GPT_4_1_MINI_2025_04_14 = "gpt-4.1-mini-2025-04-14"
3151
GPT_4_1_NANO_2025_04_14 = "gpt-4.1-nano-2025-04-14"
3252

53+
# GPT-5 models
54+
GPT_5_MINI_2025_08_07 = "gpt-5-mini-2025-08-07"
55+
GPT_5_1_2025_11_13 = "gpt-5.1-2025-11-13"
56+
GPT_5_2_2025_12_11 = "gpt-5.2-2025-12-11"
57+
3358
@property
3459
def deployment_id(self) -> str:
3560
return os.getenv(f"LLM_MODELS_{self.name}", self.value)
@@ -42,3 +67,12 @@ def is_gpt_41_family(self) -> bool:
4267
LLMModelsEnum.GPT_4_1_MINI_2025_04_14,
4368
LLMModelsEnum.GPT_4_1_NANO_2025_04_14,
4469
}
70+
71+
@property
72+
def is_gpt_5_family(self) -> bool:
73+
"""Check if the model belongs to the GPT-5 family."""
74+
return self in {
75+
LLMModelsEnum.GPT_5_MINI_2025_08_07,
76+
LLMModelsEnum.GPT_5_1_2025_11_13,
77+
LLMModelsEnum.GPT_5_2_2025_12_11,
78+
}

statgpt/common/models/models.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class Channel(DefaultBase):
3232
title: Mapped[str]
3333
description: Mapped[str]
3434
deployment_id: Mapped[str] = mapped_column(unique=True)
35-
llm_model: Mapped[str] = mapped_column(default=langchain_settings.default_model.value)
35+
llm_model: Mapped[str] = mapped_column(default=langchain_settings.embedding_default_model.value)
3636
details: Mapped[dict[str, Any]] = mapped_column(type_=postgresql.JSONB)
3737

3838
# ~~~~~ Relationships ~~~~~
@@ -150,15 +150,17 @@ class ChannelDatasetVersion(DefaultBase):
150150
__tablename__ = "channel_dataset_versions"
151151
__table_args__ = (
152152
UniqueConstraint(
153-
'channel_dataset_id', 'version', name='uix_unique_version_for_channel_dataset'
153+
"channel_dataset_id",
154+
"version",
155+
name="uix_unique_version_for_channel_dataset",
154156
),
155157
)
156158

157159
channel_dataset_id: Mapped[int] = mapped_column(ForeignKey("channel_datasets.id"))
158160
version: Mapped[int] = mapped_column(default=0) # will be auto-incremented by trigger
159161
preprocessing_status: Mapped[PreprocessingStatusEnum]
160162
pointer_to: Mapped[int | None] = mapped_column(
161-
ForeignKey("channel_dataset_versions.id", ondelete='SET NULL'), default=None
163+
ForeignKey("channel_dataset_versions.id", ondelete="SET NULL"), default=None
162164
)
163165

164166
creation_reason: Mapped[str]
@@ -181,14 +183,14 @@ class ChannelDatasetVersion(DefaultBase):
181183
channel_dataset: Mapped[ChannelDataset] = relationship(back_populates="versions")
182184
pointer = relationship(
183185
"ChannelDatasetVersion",
184-
remote_side='ChannelDatasetVersion.id',
186+
remote_side="ChannelDatasetVersion.id",
185187
back_populates="pointing_versions",
186188
cascade="all",
187189
passive_deletes=True,
188190
)
189191
pointing_versions = relationship(
190192
"ChannelDatasetVersion",
191-
remote_side='ChannelDatasetVersion.pointer_to',
193+
remote_side="ChannelDatasetVersion.pointer_to",
192194
back_populates="pointer",
193195
cascade="all, delete-orphan",
194196
passive_deletes=True,

statgpt/common/schemas/model_config.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
from pydantic import Field
1+
from pydantic import Field, model_validator
22

3-
from statgpt.common.config import EmbeddingModelsEnum, LLMModelsEnum
3+
from statgpt.common.config import (
4+
EmbeddingModelsEnum,
5+
LLMModelsEnum,
6+
ReasoningEffortEnum,
7+
VerbosityEnum,
8+
)
49
from statgpt.common.settings.langchain import langchain_settings
510

611
from .base import BaseYamlModel
@@ -10,7 +15,8 @@ class BaseModelConfig(BaseYamlModel):
1015
"""Base config for LLM and embeddings models configs."""
1116

1217
api_version: str = Field(
13-
default=langchain_settings.default_api_version, description="API version for the model"
18+
default=langchain_settings.default_api_version,
19+
description="API version for the model",
1420
)
1521

1622

@@ -30,16 +36,47 @@ class LLMModelConfig(BaseModelConfig):
3036
default=langchain_settings.default_model,
3137
description="The deployment of the model in DIAL",
3238
)
33-
temperature: float = Field(
39+
temperature: float | None = Field(
3440
default=langchain_settings.default_temperature,
3541
description=(
3642
"The temperature of the model. 0.0 means deterministic output, higher values mean more"
37-
" randomness."
43+
" randomness. Note: For reasoning models (except reasoning_effort=none) should be set to 1"
3844
),
3945
)
4046
seed: int | None = Field(
4147
default=langchain_settings.default_seed,
4248
description=(
43-
"The seed of the model. If set, the model will produce the same output for the same input."
49+
"The seed of the model. If set, the model will produce the same output for the same input. "
50+
"Note: Not supported by GPT-5 models."
4451
),
4552
)
53+
reasoning_effort: ReasoningEffortEnum | None = Field(
54+
default=None,
55+
description=(
56+
"Reasoning effort level for GPT-5 models. "
57+
"Supports: none, minimal, low, medium, high, xhigh. "
58+
"All models before gpt-5.1 default to medium reasoning effort, and do not support none."
59+
),
60+
)
61+
verbosity: VerbosityEnum | None = Field(
62+
default=None,
63+
description="Output verbosity for GPT-5 models (low/medium/high).",
64+
)
65+
66+
@model_validator(mode="after")
67+
def _validate_model_family_params(self) -> "LLMModelConfig":
68+
if self.deployment.is_gpt_5_family:
69+
if self.seed is not None:
70+
raise ValueError("seed is not supported for GPT-5 models")
71+
if self.reasoning_effort is None:
72+
raise ValueError("reasoning_effort is required for GPT-5 models")
73+
if self.reasoning_effort is not ReasoningEffortEnum.NONE and self.temperature != 1:
74+
raise ValueError(
75+
"temperature must be set to 1 when reasoning_effort is enabled for GPT-5 models"
76+
)
77+
else:
78+
if self.reasoning_effort is not None:
79+
raise ValueError("reasoning_effort is only supported for GPT-5 models")
80+
if self.verbosity is not None:
81+
raise ValueError("verbosity is only supported for GPT-5 models")
82+
return self

statgpt/common/settings/langchain.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from typing import Optional
2-
3-
from langchain import globals as lc_globals
1+
from langchain_core import globals as lc_globals
42
from pydantic import Field
53
from pydantic_settings import BaseSettings, SettingsConfigDict
64

@@ -35,7 +33,7 @@ class LangChainSettings(BaseSettings):
3533
description="Default API version for Azure OpenAI",
3634
)
3735

38-
default_seed: Optional[int] = Field(
36+
default_seed: int | None = Field(
3937
default=None,
4038
description="Default seed for reproducible outputs",
4139
)

statgpt/common/utils/models.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ def get_chat_model(
1515
model_config: LLMModelConfig,
1616
azure_endpoint: str = dial_settings.url,
1717
timeout: httpx.Timeout | None = None,
18-
**kwargs,
1918
) -> AzureChatOpenAI:
20-
# default params
2119
if not isinstance(api_key, SecretStr):
2220
api_key = SecretStr(api_key)
2321
if not timeout:
@@ -26,13 +24,12 @@ def get_chat_model(
2624
azure_endpoint=azure_endpoint,
2725
api_version=model_config.api_version,
2826
azure_deployment=model_config.deployment.deployment_id,
29-
temperature=model_config.temperature,
30-
seed=model_config.seed,
3127
max_retries=10,
3228
api_key=api_key, # since we use SecretStr, it won't be logged
3329
timeout=timeout, # timeouts are crucial!
3430
)
35-
params.update(kwargs) # update default params
31+
32+
params.update(model_config.model_dump(mode="json", exclude_none=True, exclude={"deployment"}))
3633

3734
if model_config.deployment.is_gpt_41_family:
3835
callback = BrokenResponseInterceptor(regex_pattern=r'\s{5,}')
@@ -49,7 +46,6 @@ def get_embeddings_model(
4946
api_key: str | SecretStr,
5047
model_config: EmbeddingsModelConfig,
5148
azure_endpoint: str = dial_settings.url,
52-
**kwargs,
5349
) -> AzureOpenAIEmbeddings:
5450
if not isinstance(api_key, SecretStr):
5551
api_key = SecretStr(api_key)
@@ -60,7 +56,6 @@ def get_embeddings_model(
6056
max_retries=10,
6157
api_key=api_key, # since we use SecretStr, it won't be logged
6258
)
63-
params.update(kwargs) # update default params
6459
api_key_log = f'{api_key.get_secret_value()[:3]}*****{api_key.get_secret_value()[-2:]}'
6560
logger.info(
6661
f'creating langchain embeddings with the following params: {params}, Api key: {api_key_log}'

0 commit comments

Comments (0)