From af8f3d610e04c38209366e6f7b9cbe6f40205534 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Tue, 26 Nov 2024 15:56:05 -0500 Subject: [PATCH 1/2] fix: Distinguish AIConfig from DefaultAIConfig Consumers of this SDK are responsible for providing some default value to the `config` method. If that type is of `AIConfig`, customers have the option of providing a tracker instance that will never be useful. To help discourage this, we are making two separate types -- one to represent the default ai config, and one to represent the actual selected config. --- ldai/client.py | 69 ++++++++++++++++++++++++++----- ldai/testing/test_model_config.py | 63 +++++++++++++++++----------- pyproject.toml | 1 - release-please-config.json | 2 +- 4 files changed, 99 insertions(+), 36 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index e2d53a2..c5e7e6d 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -2,19 +2,26 @@ from typing import Any, Dict, List, Literal, Optional import chevron -from dataclasses_json import dataclass_json from ldclient import Context from ldclient.client import LDClient from ldai.tracker import LDAIConfigTracker -@dataclass_json @dataclass class LDMessage: role: Literal['system', 'user', 'assistant'] content: str + def to_dict(self) -> dict: + """ + Render the given message as a dictionary object. + """ + return { + 'role': self.role, + 'content': self.content, + } + class ModelConfig: """ @@ -62,6 +69,16 @@ def get_custom(self, key: str) -> Any: return self._custom.get(key) + def to_dict(self) -> dict: + """ + Render the given model config as a dictionary object. + """ + return { + 'id': self._id, + 'parameters': self._parameters, + 'custom': self._custom, + } + class ProviderConfig: """ @@ -78,14 +95,46 @@ def id(self) -> str: """ return self._id + def to_dict(self) -> dict: + """ + Render the given provider config as a dictionary object. + """ + return { + 'id': self._id, + } + + +@dataclass(frozen=True) +class DefaultAIConfig: + """ + The default values when evaluating an AI configuration. + """ + enabled: Optional[bool] = None + model: Optional[ModelConfig] = None + messages: Optional[List[LDMessage]] = None + provider: Optional[ProviderConfig] = None + + def to_dict(self) -> dict: + """ + Render the given default values as an AIConfig-compatible dictionary object. + """ + return { + '_ldMeta': { + 'enabled': self.enabled or False, + }, + 'model': self.model.to_dict() if self.model else None, + 'messages': [message.to_dict() for message in self.messages] if self.messages else None, + 'provider': self.provider.to_dict() if self.provider else None, + } + +@dataclass(frozen=True) class AIConfig: - def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]], provider: Optional[ProviderConfig] = None): - self.tracker = tracker - self.enabled = enabled - self.model = model - self.messages = messages - self.provider = provider + tracker: LDAIConfigTracker + enabled: bool + model: Optional[ModelConfig] = None + messages: Optional[List[LDMessage]] = None + provider: Optional[ProviderConfig] = None class LDAIClient: @@ -98,7 +147,7 @@ def config( self, key: str, context: Context, - default_value: AIConfig, + default_value: DefaultAIConfig, variables: Optional[Dict[str, Any]] = None, ) -> AIConfig: """ @@ -110,7 +159,7 @@ def config( :param variables: Additional variables for the model configuration. :return: The value of the model configuration. """ - variation = self.client.variation(key, context, default_value) + variation = self.client.variation(key, context, default_value.to_dict()) all_variables = {} if variables: diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 6f97a4d..bfb235b 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -2,8 +2,7 @@ from ldclient import Config, Context, LDClient from ldclient.integrations.test_data import TestData -from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig -from ldai.tracker import LDAIConfigTracker +from ldai.client import DefaultAIConfig, LDAIClient, LDMessage, ModelConfig @pytest.fixture @@ -98,11 +97,6 @@ def client(td: TestData) -> LDClient: return LDClient(config=config) -@pytest.fixture -def tracker(client: LDClient) -> LDAIConfigTracker: - return LDAIConfigTracker(client, 'abcd', 'model-config', Context.create('user-key')) - - @pytest.fixture def ldai_client(client: LDClient) -> LDAIClient: return LDAIClient(client) @@ -125,10 +119,31 @@ def test_model_config_handles_custom(): assert model.get_custom('id') is None -def test_model_config_interpolation(ldai_client: LDAIClient, tracker): +def test_uses_default_on_invalid_flag(ldai_client: LDAIClient): + context = Context.create('user-key') + default_value = DefaultAIConfig( + enabled=True, + model=ModelConfig('fakeModel', parameters={'temperature': 0.5, 'maxTokens': 4096}), + messages=[LDMessage(role='system', content='Hello, {{name}}!')], + ) + variables = {'name': 'World'} + + config = ldai_client.config('missing-flag', context, default_value, variables) + + assert config.messages is not None + assert len(config.messages) > 0 + assert config.messages[0].content == 'Hello, World!' + assert config.enabled is True + + assert config.model is not None + assert config.model.id == 'fakeModel' + assert config.model.get_parameter('temperature') == 0.5 + assert config.model.get_parameter('maxTokens') == 4096 + + +def test_model_config_interpolation(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig( - tracker=tracker, + default_value = DefaultAIConfig( enabled=True, model=ModelConfig('fakeModel'), messages=[LDMessage(role='system', content='Hello, {{name}}!')], @@ -148,9 +163,9 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker): assert config.model.get_parameter('maxTokens') == 4096 -def test_model_config_no_variables(ldai_client: LDAIClient, tracker): +def test_model_config_no_variables(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) config = ldai_client.config('model-config', context, default_value, {}) @@ -165,9 +180,9 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker): assert config.model.get_parameter('maxTokens') == 4096 -def test_provider_config_handling(ldai_client: LDAIClient, tracker): +def test_provider_config_handling(ldai_client: LDAIClient): context = Context.builder('user-key').name("Sandy").build() - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World'} config = ldai_client.config('model-config', context, default_value, variables) @@ -176,9 +191,9 @@ def test_provider_config_handling(ldai_client: LDAIClient, tracker): assert config.provider.id == 'fakeProvider' -def test_context_interpolation(ldai_client: LDAIClient, tracker): +def test_context_interpolation(ldai_client: LDAIClient): context = Context.builder('user-key').name("Sandy").build() - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World'} config = ldai_client.config( @@ -197,9 +212,9 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker): assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to' -def test_model_config_multiple(ldai_client: LDAIClient, tracker): +def test_model_config_multiple(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World', 'day': 'Monday'} config = ldai_client.config( @@ -218,9 +233,9 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker): assert config.model.get_parameter('maxTokens') == 8192 -def test_model_config_disabled(ldai_client: LDAIClient, tracker): +def test_model_config_disabled(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) + default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) config = ldai_client.config('off-config', context, default_value, {}) @@ -231,9 +246,9 @@ def test_model_config_disabled(ldai_client: LDAIClient, tracker): assert config.model.get_parameter('maxTokens') is None -def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker): +def test_model_initial_config_disabled(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) + default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) config = ldai_client.config('initial-config-disabled', context, default_value, {}) @@ -243,9 +258,9 @@ def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker): assert config.provider is None -def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker): +def test_model_initial_config_enabled(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) + default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) config = ldai_client.config('initial-config-enabled', context, default_value, {}) diff --git a/pyproject.toml b/pyproject.toml index f301084..943049e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ exclude = [ python = ">=3.8,<4" launchdarkly-server-sdk = ">=9.4.0" chevron = "=0.14.0" -dataclasses-json = "^0.6.7" [tool.poetry.group.dev.dependencies] diff --git a/release-please-config.json b/release-please-config.json index a7ac352..78df6d7 100644 --- a/release-please-config.json +++ b/release-please-config.json @@ -3,7 +3,7 @@ ".": { "release-type": "python", "versioning": "default", - "release-as": "0.3.0", + "bump-minor-pre-major": true, "include-v-in-tag": false, "extra-files": ["ldai/__init__.py", "PROVENANCE.md"], "include-component-in-tag": false From b2eb9abf55c987cd199baf8d69cbf65410270dac Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Wed, 27 Nov 2024 10:05:26 -0500 Subject: [PATCH 2/2] Return config and tracker separately --- ldai/client.py | 39 ++++++++++++------------------- ldai/testing/test_model_config.py | 38 +++++++++++++++--------------- 2 files changed, 34 insertions(+), 43 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index c5e7e6d..1eceeac 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Tuple import chevron from ldclient import Context @@ -105,10 +105,7 @@ def to_dict(self) -> dict: @dataclass(frozen=True) -class DefaultAIConfig: - """ - The default values when evaluating an AI configuration. - """ +class AIConfig: enabled: Optional[bool] = None model: Optional[ModelConfig] = None messages: Optional[List[LDMessage]] = None @@ -128,15 +125,6 @@ def to_dict(self) -> dict: } -@dataclass(frozen=True) -class AIConfig: - tracker: LDAIConfigTracker - enabled: bool - model: Optional[ModelConfig] = None - messages: Optional[List[LDMessage]] = None - provider: Optional[ProviderConfig] = None - - class LDAIClient: """The LaunchDarkly AI SDK client object.""" @@ -147,9 +135,9 @@ def config( self, key: str, context: Context, - default_value: DefaultAIConfig, + default_value: AIConfig, variables: Optional[Dict[str, Any]] = None, - ) -> AIConfig: + ) -> Tuple[AIConfig, LDAIConfigTracker]: """ Get the value of a model configuration. @@ -157,7 +145,7 @@ def config( :param context: The context to evaluate the model configuration in. :param default_value: The default value of the model configuration. :param variables: Additional variables for the model configuration. - :return: The value of the model configuration. + :return: The value of the model configuration along with a tracker used for gathering metrics. """ variation = self.client.variation(key, context, default_value.to_dict()) @@ -195,20 +183,23 @@ def config( custom=custom ) + tracker = LDAIConfigTracker( + self.client, + variation.get('_ldMeta', {}).get('versionKey', ''), + key, + context, + ) + enabled = variation.get('_ldMeta', {}).get('enabled', False) - return AIConfig( - tracker=LDAIConfigTracker( - self.client, - variation.get('_ldMeta', {}).get('versionKey', ''), - key, - context, - ), + config = AIConfig( enabled=bool(enabled), model=model, messages=messages, provider=provider_config, ) + return config, tracker + def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str: """ Interpolate the template with the given variables. diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index bfb235b..593a037 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -2,7 +2,7 @@ from ldclient import Config, Context, LDClient from ldclient.integrations.test_data import TestData -from ldai.client import DefaultAIConfig, LDAIClient, LDMessage, ModelConfig +from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig @pytest.fixture @@ -121,14 +121,14 @@ def test_model_config_handles_custom(): def test_uses_default_on_invalid_flag(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = DefaultAIConfig( + default_value = AIConfig( enabled=True, model=ModelConfig('fakeModel', parameters={'temperature': 0.5, 'maxTokens': 4096}), messages=[LDMessage(role='system', content='Hello, {{name}}!')], ) variables = {'name': 'World'} - config = ldai_client.config('missing-flag', context, default_value, variables) + config, _ = ldai_client.config('missing-flag', context, default_value, variables) assert config.messages is not None assert len(config.messages) > 0 @@ -143,14 +143,14 @@ def test_uses_default_on_invalid_flag(ldai_client: LDAIClient): def test_model_config_interpolation(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = DefaultAIConfig( + default_value = AIConfig( enabled=True, model=ModelConfig('fakeModel'), messages=[LDMessage(role='system', content='Hello, {{name}}!')], ) variables = {'name': 'World'} - config = ldai_client.config('model-config', context, default_value, variables) + config, _ = ldai_client.config('model-config', context, default_value, variables) assert config.messages is not None assert len(config.messages) > 0 @@ -165,9 +165,9 @@ def test_model_config_interpolation(ldai_client: LDAIClient): def test_model_config_no_variables(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.config('model-config', context, default_value, {}) + config, _ = ldai_client.config('model-config', context, default_value, {}) assert config.messages is not None assert len(config.messages) > 0 @@ -182,10 +182,10 @@ def test_model_config_no_variables(ldai_client: LDAIClient): def test_provider_config_handling(ldai_client: LDAIClient): context = Context.builder('user-key').name("Sandy").build() - default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World'} - config = ldai_client.config('model-config', context, default_value, variables) + config, _ = ldai_client.config('model-config', context, default_value, variables) assert config.provider is not None assert config.provider.id == 'fakeProvider' @@ -193,10 +193,10 @@ def test_provider_config_handling(ldai_client: LDAIClient): def test_context_interpolation(ldai_client: LDAIClient): context = Context.builder('user-key').name("Sandy").build() - default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World'} - config = ldai_client.config( + config, _ = ldai_client.config( 'ctx-interpolation', context, default_value, variables ) @@ -214,10 +214,10 @@ def test_context_interpolation(ldai_client: LDAIClient): def test_model_config_multiple(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World', 'day': 'Monday'} - config = ldai_client.config( + config, _ = ldai_client.config( 'multiple-messages', context, default_value, variables ) @@ -235,9 +235,9 @@ def test_model_config_multiple(ldai_client: LDAIClient): def test_model_config_disabled(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.config('off-config', context, default_value, {}) + config, _ = ldai_client.config('off-config', context, default_value, {}) assert config.model is not None assert config.enabled is False @@ -248,9 +248,9 @@ def test_model_config_disabled(ldai_client: LDAIClient): def test_model_initial_config_disabled(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.config('initial-config-disabled', context, default_value, {}) + config, _ = ldai_client.config('initial-config-disabled', context, default_value, {}) assert config.enabled is False assert config.model is None @@ -260,9 +260,9 @@ def test_model_initial_config_disabled(ldai_client: LDAIClient): def test_model_initial_config_enabled(ldai_client: LDAIClient): context = Context.create('user-key') - default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) + default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.config('initial-config-enabled', context, default_value, {}) + config, _ = ldai_client.config('initial-config-enabled', context, default_value, {}) assert config.enabled is True assert config.model is None