diff --git a/ldai/client.py b/ldai/client.py index e2d53a2..1eceeac 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -1,20 +1,27 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Tuple import chevron -from dataclasses_json import dataclass_json from ldclient import Context from ldclient.client import LDClient from ldai.tracker import LDAIConfigTracker -@dataclass_json @dataclass class LDMessage: role: Literal['system', 'user', 'assistant'] content: str + def to_dict(self) -> dict: + """ + Render the given message as a dictionary object. + """ + return { + 'role': self.role, + 'content': self.content, + } + class ModelConfig: """ @@ -62,6 +69,16 @@ def get_custom(self, key: str) -> Any: return self._custom.get(key) + def to_dict(self) -> dict: + """ + Render the given model config as a dictionary object. + """ + return { + 'id': self._id, + 'parameters': self._parameters, + 'custom': self._custom, + } + class ProviderConfig: """ @@ -78,14 +95,34 @@ def id(self) -> str: """ return self._id + def to_dict(self) -> dict: + """ + Render the given provider config as a dictionary object. + """ + return { + 'id': self._id, + } + +@dataclass(frozen=True) class AIConfig: - def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]], provider: Optional[ProviderConfig] = None): - self.tracker = tracker - self.enabled = enabled - self.model = model - self.messages = messages - self.provider = provider + enabled: Optional[bool] = None + model: Optional[ModelConfig] = None + messages: Optional[List[LDMessage]] = None + provider: Optional[ProviderConfig] = None + + def to_dict(self) -> dict: + """ + Render the given default values as an AIConfig-compatible dictionary object. + """ + return { + '_ldMeta': { + 'enabled': self.enabled or False, + }, + 'model': self.model.to_dict() if self.model else None, + 'messages': [message.to_dict() for message in self.messages] if self.messages else None, + 'provider': self.provider.to_dict() if self.provider else None, + } class LDAIClient: @@ -100,7 +137,7 @@ def config( context: Context, default_value: AIConfig, variables: Optional[Dict[str, Any]] = None, - ) -> AIConfig: + ) -> Tuple[AIConfig, LDAIConfigTracker]: """ Get the value of a model configuration. @@ -108,9 +145,9 @@ def config( :param context: The context to evaluate the model configuration in. :param default_value: The default value of the model configuration. :param variables: Additional variables for the model configuration. - :return: The value of the model configuration. + :return: The value of the model configuration along with a tracker used for gathering metrics. """ - variation = self.client.variation(key, context, default_value) + variation = self.client.variation(key, context, default_value.to_dict()) all_variables = {} if variables: @@ -146,20 +183,23 @@ def config( custom=custom ) + tracker = LDAIConfigTracker( + self.client, + variation.get('_ldMeta', {}).get('versionKey', ''), + key, + context, + ) + enabled = variation.get('_ldMeta', {}).get('enabled', False) - return AIConfig( - tracker=LDAIConfigTracker( - self.client, - variation.get('_ldMeta', {}).get('versionKey', ''), - key, - context, - ), + config = AIConfig( enabled=bool(enabled), model=model, messages=messages, provider=provider_config, ) + return config, tracker + def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str: """ Interpolate the template with the given variables. diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 6f97a4d..593a037 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -3,7 +3,6 @@ from ldclient.integrations.test_data import TestData from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig -from ldai.tracker import LDAIConfigTracker @pytest.fixture @@ -98,11 +97,6 @@ def client(td: TestData) -> LDClient: return LDClient(config=config) -@pytest.fixture -def tracker(client: LDClient) -> LDAIConfigTracker: - return LDAIConfigTracker(client, 'abcd', 'model-config', Context.create('user-key')) - - @pytest.fixture def ldai_client(client: LDClient) -> LDAIClient: return LDAIClient(client) @@ -125,17 +119,38 @@ def test_model_config_handles_custom(): assert model.get_custom('id') is None -def test_model_config_interpolation(ldai_client: LDAIClient, tracker): +def test_uses_default_on_invalid_flag(ldai_client: LDAIClient): + context = Context.create('user-key') + default_value = AIConfig( + enabled=True, + model=ModelConfig('fakeModel', parameters={'temperature': 0.5, 'maxTokens': 4096}), + messages=[LDMessage(role='system', content='Hello, {{name}}!')], + ) + variables = {'name': 'World'} + + config, _ = ldai_client.config('missing-flag', context, default_value, variables) + + assert config.messages is not None + assert len(config.messages) > 0 + assert config.messages[0].content == 'Hello, World!' + assert config.enabled is True + + assert config.model is not None + assert config.model.id == 'fakeModel' + assert config.model.get_parameter('temperature') == 0.5 + assert config.model.get_parameter('maxTokens') == 4096 + + +def test_model_config_interpolation(ldai_client: LDAIClient): context = Context.create('user-key') default_value = AIConfig( - tracker=tracker, enabled=True, model=ModelConfig('fakeModel'), messages=[LDMessage(role='system', content='Hello, {{name}}!')], ) variables = {'name': 'World'} - config = ldai_client.config('model-config', context, default_value, variables) + config, _ = ldai_client.config('model-config', context, default_value, variables) assert config.messages is not None assert len(config.messages) > 0 @@ -148,11 +163,11 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker): assert config.model.get_parameter('maxTokens') == 4096 -def test_model_config_no_variables(ldai_client: LDAIClient, tracker): +def test_model_config_no_variables(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.config('model-config', context, default_value, {}) + config, _ = ldai_client.config('model-config', context, default_value, {}) assert config.messages is not None assert len(config.messages) > 0 @@ -165,23 +180,23 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker): assert config.model.get_parameter('maxTokens') == 4096 -def test_provider_config_handling(ldai_client: LDAIClient, tracker): +def test_provider_config_handling(ldai_client: LDAIClient): context = Context.builder('user-key').name("Sandy").build() - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World'} - config = ldai_client.config('model-config', context, default_value, variables) + config, _ = ldai_client.config('model-config', context, default_value, variables) assert config.provider is not None assert config.provider.id == 'fakeProvider' -def test_context_interpolation(ldai_client: LDAIClient, tracker): +def test_context_interpolation(ldai_client: LDAIClient): context = Context.builder('user-key').name("Sandy").build() - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World'} - config = ldai_client.config( + config, _ = ldai_client.config( 'ctx-interpolation', context, default_value, variables ) @@ -197,12 +212,12 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker): assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to' -def test_model_config_multiple(ldai_client: LDAIClient, tracker): +def test_model_config_multiple(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World', 'day': 'Monday'} - config = ldai_client.config( + config, _ = ldai_client.config( 'multiple-messages', context, default_value, variables ) @@ -218,11 +233,11 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker): assert config.model.get_parameter('maxTokens') == 8192 -def test_model_config_disabled(ldai_client: LDAIClient, tracker): +def test_model_config_disabled(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.config('off-config', context, default_value, {}) + config, _ = ldai_client.config('off-config', context, default_value, {}) assert config.model is not None assert config.enabled is False @@ -231,11 +246,11 @@ def test_model_config_disabled(ldai_client: LDAIClient, tracker): assert config.model.get_parameter('maxTokens') is None -def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker): +def test_model_initial_config_disabled(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.config('initial-config-disabled', context, default_value, {}) + config, _ = ldai_client.config('initial-config-disabled', context, default_value, {}) assert config.enabled is False assert config.model is None @@ -243,11 +258,11 @@ def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker): assert config.provider is None -def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker): +def test_model_initial_config_enabled(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.config('initial-config-enabled', context, default_value, {}) + config, _ = ldai_client.config('initial-config-enabled', context, default_value, {}) assert config.enabled is True assert config.model is None diff --git a/pyproject.toml b/pyproject.toml index f301084..943049e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ exclude = [ python = ">=3.8,<4" launchdarkly-server-sdk = ">=9.4.0" chevron = "=0.14.0" -dataclasses-json = "^0.6.7" [tool.poetry.group.dev.dependencies] diff --git a/release-please-config.json b/release-please-config.json index a7ac352..78df6d7 100644 --- a/release-please-config.json +++ b/release-please-config.json @@ -3,7 +3,7 @@ ".": { "release-type": "python", "versioning": "default", - "release-as": "0.3.0", + "bump-minor-pre-major": true, "include-v-in-tag": false, "extra-files": ["ldai/__init__.py", "PROVENANCE.md"], "include-component-in-tag": false