2 changes: 2 additions & 0 deletions Makefile
@@ -30,6 +30,8 @@ test: install
lint: #! Run type analysis and linting checks
lint: install
@poetry run mypy ldai
@poetry run isort --check --atomic ldai
@poetry run pycodestyle ldai

#
# Documentation generation
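The two new lint steps enforce import ordering (isort) and PEP 8 style (pycodestyle) on the ldai package. The import reshuffle at the top of ldai/client.py below is exactly the kind of change isort --check flags; a minimal sketch of the sorted form taken from this diff, grouped standard library / third-party / first-party:

# Import order enforced by isort: standard library, then third-party, then the local ldai package.
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional

import chevron
from ldclient import Context
from ldclient.client import LDClient

from ldai.tracker import LDAIConfigTracker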
49 changes: 37 additions & 12 deletions ldai/client.py
@@ -1,33 +1,45 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional

import chevron
from ldclient import Context
from ldclient.client import LDClient
import chevron

from ldai.tracker import LDAIConfigTracker
from dataclasses import dataclass


@dataclass
class LDMessage():
class LDMessage:
role: Literal['system', 'user', 'assistant']
content: str


@dataclass
class AIConfigData():
class AIConfigData:
model: Optional[dict]
prompt: Optional[List[LDMessage]]
class AIConfig():


class AIConfig:
def __init__(self, config: AIConfigData, tracker: LDAIConfigTracker, enabled: bool):
self.config = config
self.tracker = tracker
self.enabled = enabled


class LDAIClient:
"""The LaunchDarkly AI SDK client object."""

def __init__(self, client: LDClient):
self.client = client

def model_config(self, key: str, context: Context, default_value: AIConfig, variables: Optional[Dict[str, Any]] = None) -> AIConfig:
def model_config(
self,
key: str,
context: Context,
default_value: AIConfig,
variables: Optional[Dict[str, Any]] = None,
) -> AIConfig:
"""
Get the value of a model configuration.

@@ -43,18 +55,31 @@ def model_config(self, key: str, context: Context, default_value: AIConfig, vari
if variables:
all_variables.update(variables)
all_variables['ldctx'] = context

if isinstance(variation['prompt'], list) and all(isinstance(entry, dict) for entry in variation['prompt']):

if isinstance(variation['prompt'], list) and all(
isinstance(entry, dict) for entry in variation['prompt']
):
variation['prompt'] = [
LDMessage(
role=entry['role'],
content=self.__interpolate_template(entry['content'], all_variables)
content=self.__interpolate_template(
entry['content'], all_variables
),
)
for entry in variation['prompt']
]

enabled = variation.get('_ldMeta',{}).get('enabled', False)
return AIConfig(config=AIConfigData(model=variation['model'], prompt=variation['prompt']), tracker=LDAIConfigTracker(self.client, variation.get('_ldMeta', {}).get('versionKey', ''), key, context), enabled=bool(enabled))
enabled = variation.get('_ldMeta', {}).get('enabled', False)
return AIConfig(
config=AIConfigData(model=variation['model'], prompt=variation['prompt']),
tracker=LDAIConfigTracker(
self.client,
variation.get('_ldMeta', {}).get('versionKey', ''),
key,
context,
),
enabled=bool(enabled),
)

def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:
"""
@@ -64,4 +89,4 @@ def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> st
:variables: The variables to interpolate into the template.
:return: The interpolated string.
"""
return chevron.render(template, variables)
return chevron.render(template, variables)
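For context on the reformatted signature, here is a minimal usage sketch of LDAIClient.model_config, assuming a placeholder SDK key and an illustrative flag key 'model-config'; the variables dict is interpolated into each prompt message with chevron, and the evaluation context is injected under the ldctx key so templates like {{ldctx.name}} resolve from the context:

from ldclient import Config, Context, LDClient

from ldai.client import AIConfig, AIConfigData, LDAIClient, LDMessage
from ldai.tracker import LDAIConfigTracker

# 'sdk-key' and the 'model-config' flag key are placeholders for illustration.
ld_client = LDClient(config=Config('sdk-key'))
ai_client = LDAIClient(ld_client)

context = Context.builder('user-key').name('Sandy').build()
default_value = AIConfig(
    config=AIConfigData(
        model={'modelId': 'fakeModel'},
        prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
    ),
    tracker=LDAIConfigTracker(ld_client, '', 'model-config', context),
    enabled=False,
)

# The variables dict is merged with {'ldctx': context} before interpolation.
config = ai_client.model_config('model-config', context, default_value, {'name': 'World'})
if config.enabled and config.config.prompt:
    print(config.config.prompt[0].content)  # e.g. 'Hello, World!'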
121 changes: 89 additions & 32 deletions ldai/testing/test_model_config.py
@@ -1,56 +1,97 @@
import pytest
from ldclient import LDClient, Context, Config
from ldclient import Config, Context, LDClient
from ldclient.integrations.test_data import TestData
from ldclient.testing.builders import *

from ldai.client import AIConfig, AIConfigData, LDAIClient, LDMessage
from ldai.tracker import LDAIConfigTracker
from ldclient.testing.builders import *


@pytest.fixture
def td() -> TestData:
td = TestData.data_source()
td.update(td.flag('model-config').variations({
'model': { 'modelId': 'fakeModel'},
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'}
}, "green").variation_for_all(0))

td.update(td.flag('multiple-prompt').variations({
'model': { 'modelId': 'fakeModel'},
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}, {'role': 'user', 'content': 'The day is, {{day}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'}
}, "green").variation_for_all(0))

td.update(td.flag('ctx-interpolation').variations({
'model': { 'modelId': 'fakeModel'},
'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'}
}).variation_for_all(0))

td.update(td.flag('off-config').variations({
'model': { 'modelId': 'fakeModel'},
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': False, 'versionKey': 'abcd'}
}).variation_for_all(0))
td.update(
td.flag('model-config')
.variations(
{
'model': {'modelId': 'fakeModel'},
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
},
"green",
)
.variation_for_all(0)
)

td.update(
td.flag('multiple-prompt')
.variations(
{
'model': {'modelId': 'fakeModel'},
'prompt': [
{'role': 'system', 'content': 'Hello, {{name}}!'},
{'role': 'user', 'content': 'The day is, {{day}}!'},
],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
},
"green",
)
.variation_for_all(0)
)

td.update(
td.flag('ctx-interpolation')
.variations(
{
'model': {'modelId': 'fakeModel'},
'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
}
)
.variation_for_all(0)
)

td.update(
td.flag('off-config')
.variations(
{
'model': {'modelId': 'fakeModel'},
'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': False, 'versionKey': 'abcd'},
}
)
.variation_for_all(0)
)

return td


@pytest.fixture
def client(td: TestData) -> LDClient:
config = Config('sdk-key', update_processor_class=td, send_events=False)
return LDClient(config=config)


@pytest.fixture
def tracker(client: LDClient) -> LDAIConfigTracker:
return LDAIConfigTracker(client, 'abcd', 'model-config', Context.create('user-key'))


@pytest.fixture
def ldai_client(client: LDClient) -> LDAIClient:
return LDAIClient(client)


def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
context = Context.create('user-key')
default_value = AIConfig(config=AIConfigData(model={ 'modelId': 'fakeModel'}, prompt=[LDMessage(role='system', content='Hello, {{name}}!')]), tracker=tracker, enabled=True)
default_value = AIConfig(
config=AIConfigData(
model={'modelId': 'fakeModel'},
prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
),
tracker=tracker,
enabled=True,
)
variables = {'name': 'World'}

config = ldai_client.model_config('model-config', context, default_value, variables)
@@ -60,9 +101,12 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
assert config.config.prompt[0].content == 'Hello, World!'
assert config.enabled is True


def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
context = Context.create('user-key')
default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True)
default_value = AIConfig(
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
)

config = ldai_client.model_config('model-config', context, default_value, {})

@@ -71,34 +115,47 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
assert config.config.prompt[0].content == 'Hello, !'
assert config.enabled is True


def test_context_interpolation(ldai_client: LDAIClient, tracker):
context = Context.builder('user-key').name("Sandy").build()
default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True)
default_value = AIConfig(
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
)
variables = {'name': 'World'}

config = ldai_client.model_config('ctx-interpolation', context, default_value, variables)
config = ldai_client.model_config(
'ctx-interpolation', context, default_value, variables
)

assert config.config.prompt is not None
assert len(config.config.prompt) > 0
assert config.config.prompt[0].content == 'Hello, Sandy!'
assert config.enabled is True


def test_model_config_multiple(ldai_client: LDAIClient, tracker):
context = Context.create('user-key')
default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True)
default_value = AIConfig(
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
)
variables = {'name': 'World', 'day': 'Monday'}

config = ldai_client.model_config('multiple-prompt', context, default_value, variables)
config = ldai_client.model_config(
'multiple-prompt', context, default_value, variables
)

assert config.config.prompt is not None
assert len(config.config.prompt) > 0
assert config.config.prompt[0].content == 'Hello, World!'
assert config.config.prompt[1].content == 'The day is, Monday!'
assert config.enabled is True


def test_model_config_disabled(ldai_client: LDAIClient, tracker):
context = Context.create('user-key')
default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=False)
default_value = AIConfig(
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=False
)

config = ldai_client.model_config('off-config', context, default_value, {})

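One detail worth calling out in the reformatted fixtures above: each .variations(...) call passes the AI config dict as variation 0 (with "green" as an unused second variation where present), and .variation_for_all(0) pins every evaluation to that dict. A minimal sketch of the same TestData pattern, assuming a hypothetical flag key 'example-config':

from ldclient.integrations.test_data import TestData

td = TestData.data_source()
td.update(
    td.flag('example-config')  # hypothetical flag key for illustration
    .variations(
        {  # variation 0: the AI config payload served to every context
            'model': {'modelId': 'fakeModel'},
            'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
            '_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
        },
        "green",  # variation 1: placeholder, never served here
    )
    .variation_for_all(0)  # pin all contexts to variation 0
)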