Skip to content

Commit d1257fc

Browse files
author
Daniel OBrien
committed
fix tests
1 parent 5cae616 commit d1257fc

File tree

4 files changed

+100
-103
lines changed

4 files changed

+100
-103
lines changed

ldai/client.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
1-
from typing import Any, Dict, Optional
1+
from typing import Any, Dict, List, Literal, Optional
22
from ldclient import Context
33
from ldclient.client import LDClient
44
import chevron
55

66
from ldai.tracker import LDAIConfigTracker
7-
from ldai.types import AIConfig
7+
from dataclasses import dataclass
8+
9+
@dataclass
10+
class LDMessage():
11+
role: Literal['system', 'user', 'assistant']
12+
content: str
13+
14+
@dataclass
15+
class AIConfigData():
16+
model: Optional[dict]
17+
prompt: Optional[List[LDMessage]]
18+
class AIConfig():
19+
def __init__(self, config: AIConfigData, tracker: LDAIConfigTracker, enabled: bool):
20+
self.config = config
21+
self.tracker = tracker
22+
self.enabled = enabled
823

924
class LDAIClient:
1025
"""The LaunchDarkly AI SDK client object."""
@@ -31,15 +46,15 @@ def model_config(self, key: str, context: Context, default_value: AIConfig, vari
3146
print(variation)
3247
if isinstance(variation['prompt'], list) and all(isinstance(entry, dict) for entry in variation['prompt']):
3348
variation['prompt'] = [
34-
{
35-
'role': entry['role'],
36-
'content': self.interpolate_template(entry['content'], all_variables)
37-
}
49+
LDMessage(
50+
role=entry['role'],
51+
content=self.interpolate_template(entry['content'], all_variables)
52+
)
3853
for entry in variation['prompt']
3954
]
4055

4156
enabled = variation.get('_ldMeta',{}).get('enabled', False)
42-
return AIConfig(config=variation, tracker=LDAIConfigTracker(self.client, variation.get('_ldMeta', {}).get('versionKey', ''), key, context), enabled=bool(enabled))
57+
return AIConfig(config=AIConfigData(model=variation['model'], prompt=variation['prompt']), tracker=LDAIConfigTracker(self.client, variation.get('_ldMeta', {}).get('versionKey', ''), key, context), enabled=bool(enabled))
4358

4459
def interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:
4560
"""

ldai/testing/test_model_config.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import pytest
22
from ldclient import LDClient, Context, Config
33
from ldclient.integrations.test_data import TestData
4-
from ldai.types import AIConfig, AIConfigData, LDMessage
5-
from ldai.client import LDAIClient
4+
from ldai.client import AIConfig, AIConfigData, LDAIClient, LDMessage
65
from ldai.tracker import LDAIConfigTracker
76
from ldclient.testing.builders import *
87

@@ -41,13 +40,17 @@ def client(td: TestData) -> LDClient:
4140
config = Config('sdk-key', update_processor_class=td, send_events=False)
4241
return LDClient(config=config)
4342

43+
@pytest.fixture
44+
def tracker(td: TestData) -> LDAIConfigTracker:
45+
return LDAIConfigTracker(client(td), 'abcd', 'model-config', Context.create('user-key'))
46+
4447
@pytest.fixture
4548
def ldai_client(client: LDClient) -> LDAIClient:
4649
return LDAIClient(client)
4750

4851
def test_model_config_interpolation(ldai_client: LDAIClient):
4952
context = Context.create('user-key')
50-
default_value = AIConfig(config=AIConfigData(model={ 'modelId': 'fakeModel'}, prompt=[LDMessage(role='system', content='Hello, {{name}}!')], _ldMeta={'enabled': True, 'versionKey': 'abcd'}), tracker=LDAIConfigTracker(), enabled=True)
53+
default_value = AIConfig(config=AIConfigData(model={ 'modelId': 'fakeModel'}, prompt=[LDMessage(role='system', content='Hello, {{name}}!')]), tracker=tracker(td()), enabled=True)
5154
variables = {'name': 'World'}
5255

5356
config = ldai_client.model_config('model-config', context, default_value, variables)
@@ -59,7 +62,7 @@ def test_model_config_interpolation(ldai_client: LDAIClient):
5962

6063
def test_model_config_no_variables(ldai_client: LDAIClient):
6164
context = Context.create('user-key')
62-
default_value = AIConfig(config=AIConfigData(model={}, prompt=[], _ldMeta={'enabled': True, 'versionKey': 'abcd'}), tracker=LDAIConfigTracker(), enabled=True)
65+
default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker(td()), enabled=True)
6366

6467
config = ldai_client.model_config('model-config', context, default_value, {})
6568

@@ -70,7 +73,7 @@ def test_model_config_no_variables(ldai_client: LDAIClient):
7073

7174
def test_context_interpolation(ldai_client: LDAIClient):
7275
context = Context.builder('user-key').name("Sandy").build()
73-
default_value = AIConfig(config=AIConfigData(model={}, prompt=[], _ldMeta={'enabled': True, 'versionKey': 'abcd'}), tracker=LDAIConfigTracker(), enabled=True)
76+
default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker(td()), enabled=True)
7477
variables = {'name': 'World'}
7578

7679
config = ldai_client.model_config('ctx-interpolation', context, default_value, variables)
@@ -80,17 +83,9 @@ def test_context_interpolation(ldai_client: LDAIClient):
8083
assert config.config.prompt[0].content == 'Hello, Sandy!'
8184
assert config.enabled is True
8285

83-
def test_model_config_disabled(ldai_client: LDAIClient):
84-
context = Context.create('user-key')
85-
default_value = AIConfig(config=AIConfigData(model={}, prompt=[], _ldMeta={'enabled': True, 'versionKey': 'abcd'}), tracker=LDAIConfigTracker(), enabled=True)
86-
87-
config = ldai_client.model_config('off-config', context, default_value, {})
88-
89-
assert config.enabled is False
90-
9186
def test_model_config_multiple(ldai_client: LDAIClient):
9287
context = Context.create('user-key')
93-
default_value = AIConfig(config=AIConfigData(model={}, prompt=[], _ldMeta={'enabled': True, 'versionKey': 'abcd'}), tracker=LDAIConfigTracker(), enabled=True)
88+
default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker(td()), enabled=True)
9489
variables = {'name': 'World', 'day': 'Monday'}
9590

9691
config = ldai_client.model_config('multiple-prompt', context, default_value, variables)
@@ -99,4 +94,12 @@ def test_model_config_multiple(ldai_client: LDAIClient):
9994
assert len(config.config.prompt) > 0
10095
assert config.config.prompt[0].content == 'Hello, World!'
10196
assert config.config.prompt[1].content == 'The day is, Monday!'
102-
assert config.enabled is True
97+
assert config.enabled is True
98+
99+
def test_model_config_disabled(ldai_client: LDAIClient):
100+
context = Context.create('user-key')
101+
default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker(td()), enabled=False)
102+
103+
config = ldai_client.model_config('off-config', context, default_value, {})
104+
105+
assert config.enabled is False

ldai/tracker.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,67 @@
1+
from enum import Enum
12
import time
23
from typing import Dict, Union
34
from ldclient import Context, LDClient
4-
from ldai.types import BedrockTokenUsage, FeedbackKind, OpenAITokenUsage, TokenUsage
5+
from dataclasses import dataclass
56

7+
@dataclass
8+
class TokenMetrics():
9+
total: int
10+
input: int
11+
output: int # type: ignore
12+
13+
@dataclass
14+
class FeedbackKind(Enum):
15+
Positive = "positive"
16+
Negative = "negative"
17+
18+
@dataclass
19+
class TokenUsage():
20+
total_tokens: int
21+
prompt_tokens: int
22+
completion_tokens: int
23+
24+
def to_metrics(self):
25+
return {
26+
'total': self['total_tokens'],
27+
'input': self['prompt_tokens'],
28+
'output': self['completion_tokens'],
29+
}
30+
31+
@dataclass
32+
class LDOpenAIUsage():
33+
total_tokens: int
34+
prompt_tokens: int
35+
completion_tokens: int
36+
37+
@dataclass
38+
class OpenAITokenUsage:
39+
def __init__(self, data: LDOpenAIUsage):
40+
self.total_tokens = data.total_tokens
41+
self.prompt_tokens = data.prompt_tokens
42+
self.completion_tokens = data.completion_tokens
43+
44+
def to_metrics(self) -> TokenMetrics:
45+
return TokenMetrics(
46+
total=self.total_tokens,
47+
input=self.prompt_tokens,
48+
output=self.completion_tokens,
49+
)
50+
51+
@dataclass
52+
class BedrockTokenUsage:
53+
def __init__(self, data: dict):
54+
self.totalTokens = data.get('totalTokens', 0)
55+
self.inputTokens = data.get('inputTokens', 0)
56+
self.outputTokens = data.get('outputTokens', 0)
57+
58+
def to_metrics(self) -> TokenMetrics:
59+
return TokenMetrics(
60+
total=self.totalTokens,
61+
input=self.inputTokens,
62+
output=self.outputTokens,
63+
)
64+
665
class LDAIConfigTracker:
766
def __init__(self, ld_client: LDClient, version_key: str, config_key: str, context: Context):
867
self.ld_client = ld_client

ldai/types.py

Lines changed: 0 additions & 80 deletions
This file was deleted.

0 commit comments

Comments
 (0)