Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 59 additions & 19 deletions ldai/client.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, Tuple

import chevron
from dataclasses_json import dataclass_json
from ldclient import Context
from ldclient.client import LDClient

from ldai.tracker import LDAIConfigTracker


@dataclass_json
@dataclass
class LDMessage:
role: Literal['system', 'user', 'assistant']
content: str

def to_dict(self) -> dict:
"""
Render the given message as a dictionary object.
"""
return {
'role': self.role,
'content': self.content,
}


class ModelConfig:
"""
Expand Down Expand Up @@ -62,6 +69,16 @@ def get_custom(self, key: str) -> Any:

return self._custom.get(key)

def to_dict(self) -> dict:
"""
Render the given model config as a dictionary object.
"""
return {
'id': self._id,
'parameters': self._parameters,
'custom': self._custom,
}


class ProviderConfig:
"""
Expand All @@ -78,14 +95,34 @@ def id(self) -> str:
"""
return self._id

def to_dict(self) -> dict:
"""
Render the given provider config as a dictionary object.
"""
return {
'id': self._id,
}


@dataclass(frozen=True)
class AIConfig:
def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]], provider: Optional[ProviderConfig] = None):
self.tracker = tracker
self.enabled = enabled
self.model = model
self.messages = messages
self.provider = provider
enabled: Optional[bool] = None
model: Optional[ModelConfig] = None
messages: Optional[List[LDMessage]] = None
provider: Optional[ProviderConfig] = None

def to_dict(self) -> dict:
"""
Render the given default values as an AIConfig-compatible dictionary object.
"""
return {
'_ldMeta': {
'enabled': self.enabled or False,
},
'model': self.model.to_dict() if self.model else None,
'messages': [message.to_dict() for message in self.messages] if self.messages else None,
'provider': self.provider.to_dict() if self.provider else None,
}


class LDAIClient:
Expand All @@ -100,17 +137,17 @@ def config(
context: Context,
default_value: AIConfig,
variables: Optional[Dict[str, Any]] = None,
) -> AIConfig:
) -> Tuple[AIConfig, LDAIConfigTracker]:
"""
Get the value of a model configuration.

:param key: The key of the model configuration.
:param context: The context to evaluate the model configuration in.
:param default_value: The default value of the model configuration.
:param variables: Additional variables for the model configuration.
:return: The value of the model configuration.
:return: The value of the model configuration along with a tracker used for gathering metrics.
"""
variation = self.client.variation(key, context, default_value)
variation = self.client.variation(key, context, default_value.to_dict())

all_variables = {}
if variables:
Expand Down Expand Up @@ -146,20 +183,23 @@ def config(
custom=custom
)

tracker = LDAIConfigTracker(
self.client,
variation.get('_ldMeta', {}).get('versionKey', ''),
key,
context,
)

enabled = variation.get('_ldMeta', {}).get('enabled', False)
return AIConfig(
tracker=LDAIConfigTracker(
self.client,
variation.get('_ldMeta', {}).get('versionKey', ''),
key,
context,
),
config = AIConfig(
enabled=bool(enabled),
model=model,
messages=messages,
provider=provider_config,
)

return config, tracker

def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:
"""
Interpolate the template with the given variables.
Expand Down
75 changes: 45 additions & 30 deletions ldai/testing/test_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ldclient.integrations.test_data import TestData

from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig
from ldai.tracker import LDAIConfigTracker


@pytest.fixture
Expand Down Expand Up @@ -98,11 +97,6 @@ def client(td: TestData) -> LDClient:
return LDClient(config=config)


@pytest.fixture
def tracker(client: LDClient) -> LDAIConfigTracker:
return LDAIConfigTracker(client, 'abcd', 'model-config', Context.create('user-key'))


@pytest.fixture
def ldai_client(client: LDClient) -> LDAIClient:
return LDAIClient(client)
Expand All @@ -125,17 +119,38 @@ def test_model_config_handles_custom():
assert model.get_custom('id') is None


def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
def test_uses_default_on_invalid_flag(ldai_client: LDAIClient):
context = Context.create('user-key')
default_value = AIConfig(
enabled=True,
model=ModelConfig('fakeModel', parameters={'temperature': 0.5, 'maxTokens': 4096}),
messages=[LDMessage(role='system', content='Hello, {{name}}!')],
)
variables = {'name': 'World'}

config, _ = ldai_client.config('missing-flag', context, default_value, variables)

assert config.messages is not None
assert len(config.messages) > 0
assert config.messages[0].content == 'Hello, World!'
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.get_parameter('temperature') == 0.5
assert config.model.get_parameter('maxTokens') == 4096


def test_model_config_interpolation(ldai_client: LDAIClient):
context = Context.create('user-key')
default_value = AIConfig(
tracker=tracker,
enabled=True,
model=ModelConfig('fakeModel'),
messages=[LDMessage(role='system', content='Hello, {{name}}!')],
)
variables = {'name': 'World'}

config = ldai_client.config('model-config', context, default_value, variables)
config, _ = ldai_client.config('model-config', context, default_value, variables)

assert config.messages is not None
assert len(config.messages) > 0
Expand All @@ -148,11 +163,11 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
assert config.model.get_parameter('maxTokens') == 4096


def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
def test_model_config_no_variables(ldai_client: LDAIClient):
context = Context.create('user-key')
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])

config = ldai_client.config('model-config', context, default_value, {})
config, _ = ldai_client.config('model-config', context, default_value, {})

assert config.messages is not None
assert len(config.messages) > 0
Expand All @@ -165,23 +180,23 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
assert config.model.get_parameter('maxTokens') == 4096


def test_provider_config_handling(ldai_client: LDAIClient, tracker):
def test_provider_config_handling(ldai_client: LDAIClient):
context = Context.builder('user-key').name("Sandy").build()
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
variables = {'name': 'World'}

config = ldai_client.config('model-config', context, default_value, variables)
config, _ = ldai_client.config('model-config', context, default_value, variables)

assert config.provider is not None
assert config.provider.id == 'fakeProvider'


def test_context_interpolation(ldai_client: LDAIClient, tracker):
def test_context_interpolation(ldai_client: LDAIClient):
context = Context.builder('user-key').name("Sandy").build()
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
variables = {'name': 'World'}

config = ldai_client.config(
config, _ = ldai_client.config(
'ctx-interpolation', context, default_value, variables
)

Expand All @@ -197,12 +212,12 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker):
assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to'


def test_model_config_multiple(ldai_client: LDAIClient, tracker):
def test_model_config_multiple(ldai_client: LDAIClient):
context = Context.create('user-key')
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
variables = {'name': 'World', 'day': 'Monday'}

config = ldai_client.config(
config, _ = ldai_client.config(
'multiple-messages', context, default_value, variables
)

Expand All @@ -218,11 +233,11 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker):
assert config.model.get_parameter('maxTokens') == 8192


def test_model_config_disabled(ldai_client: LDAIClient, tracker):
def test_model_config_disabled(ldai_client: LDAIClient):
context = Context.create('user-key')
default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[])
default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])

config = ldai_client.config('off-config', context, default_value, {})
config, _ = ldai_client.config('off-config', context, default_value, {})

assert config.model is not None
assert config.enabled is False
Expand All @@ -231,23 +246,23 @@ def test_model_config_disabled(ldai_client: LDAIClient, tracker):
assert config.model.get_parameter('maxTokens') is None


def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker):
def test_model_initial_config_disabled(ldai_client: LDAIClient):
context = Context.create('user-key')
default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[])
default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])

config = ldai_client.config('initial-config-disabled', context, default_value, {})
config, _ = ldai_client.config('initial-config-disabled', context, default_value, {})

assert config.enabled is False
assert config.model is None
assert config.messages is None
assert config.provider is None


def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker):
def test_model_initial_config_enabled(ldai_client: LDAIClient):
context = Context.create('user-key')
default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[])
default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])

config = ldai_client.config('initial-config-enabled', context, default_value, {})
config, _ = ldai_client.config('initial-config-enabled', context, default_value, {})

assert config.enabled is True
assert config.model is None
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ exclude = [
python = ">=3.8,<4"
launchdarkly-server-sdk = ">=9.4.0"
chevron = "=0.14.0"
dataclasses-json = "^0.6.7"


[tool.poetry.group.dev.dependencies]
Expand Down
2 changes: 1 addition & 1 deletion release-please-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
".": {
"release-type": "python",
"versioning": "default",
"release-as": "0.3.0",
"bump-minor-pre-major": true,
"include-v-in-tag": false,
"extra-files": ["ldai/__init__.py", "PROVENANCE.md"],
"include-component-in-tag": false
Expand Down
Loading