Commit 7c02cf0

fixing up types

Author: Daniel OBrien (committed)
1 parent d97db5b, commit 7c02cf0

File tree

4 files changed (+59, -56 lines)

ldai/client.py

Lines changed: 8 additions & 8 deletions
@@ -29,14 +29,14 @@ def model_config(self, key: str, context: Context, default_value: AIConfig, vari
         all_variables.update(variables)
         all_variables['ldctx'] = context
         print(variation)
-        #if isinstance(variation['prompt'], list) and all(isinstance(entry, dict) for entry in variation['prompt']):
-        variation['prompt'] = [
-            {
-                'role': entry['role'],
-                'content': self.interpolate_template(entry['content'], all_variables)
-            }
-            for entry in variation['prompt']
-        ]
+        if isinstance(variation['prompt'], list) and all(isinstance(entry, dict) for entry in variation['prompt']):
+            variation['prompt'] = [
+                {
+                    'role': entry['role'],
+                    'content': self.interpolate_template(entry['content'], all_variables)
+                }
+                for entry in variation['prompt']
+            ]
 
         enabled = variation.get('_ldMeta',{}).get('enabled', False)
         return AIConfig(config=variation, tracker=LDAIConfigTracker(self.client, variation.get('_ldMeta', {}).get('versionKey', ''), key, context), enabled=bool(enabled))
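
For context, this change restores the type guard so prompt entries are only interpolated when the variation's prompt is a list of dicts. A minimal sketch of that behavior as a standalone function follows; the name interpolate_prompt and the Mustache-style rendering stand-in are illustrative assumptions, not the SDK's interpolate_template implementation:

    # Sketch only: interpolate each prompt entry's content when the prompt is a list of dicts.
    def interpolate_prompt(variation: dict, all_variables: dict) -> dict:
        def render(template: str) -> str:
            # Hypothetical stand-in for interpolate_template; naive '{{name}}' substitution.
            for name, value in all_variables.items():
                template = template.replace('{{' + name + '}}', str(value))
            return template

        prompt = variation.get('prompt')
        if isinstance(prompt, list) and all(isinstance(entry, dict) for entry in prompt):
            variation['prompt'] = [
                {'role': entry['role'], 'content': render(entry['content'])}
                for entry in prompt
            ]
        return variation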

ldai/testing/test_model_config.py

Lines changed: 25 additions & 23 deletions
@@ -1,10 +1,13 @@
 import pytest
 from ldclient import LDClient, Context, Config
 from ldclient.integrations.test_data import TestData
-from ldai.types import AIConfig
+from ldai.types import AIConfig, AIConfigData, LDMessage
 from ldai.client import LDAIClient
+from ldai.tracker import LDAIConfigTracker
 from ldclient.testing.builders import *
 
+
+
 @pytest.fixture
 def td() -> TestData:
     td = TestData.data_source()
@@ -45,57 +48,56 @@ def ldai_client(client: LDClient) -> LDAIClient:
 
 def test_model_config_interpolation(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(config={
-        'model': { 'modelId': 'fakeModel'},
-        'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
-        '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}
-    }, tracker=None, enabled=True)
+    default_value = AIConfig(config=AIConfigData(model={ 'modelId': 'fakeModel'}, prompt=[LDMessage(role='system', content='Hello, {{name}}!')], _ldMeta={'enabled': True, 'versionKey': 'abcd'}), tracker=LDAIConfigTracker(), enabled=True)
     variables = {'name': 'World'}
 
     config = ldai_client.model_config('model-config', context, default_value, variables)
-
-    assert config.config['prompt'][0]['content'] == 'Hello, World!'
+
+    assert config.config.prompt is not None
+    assert len(config.config.prompt) > 0
+    assert config.config.prompt[0].content == 'Hello, World!'
     assert config.enabled is True
-    assert config.tracker.version_key == 'abcd'
 
 def test_model_config_no_variables(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(config={}, tracker=None, enabled=True)
+    default_value = AIConfig(config=AIConfigData(model={}, prompt=[], _ldMeta={'enabled': True, 'versionKey': 'abcd'}), tracker=LDAIConfigTracker(), enabled=True)
 
     config = ldai_client.model_config('model-config', context, default_value, {})
 
-    assert config.config['prompt'][0]['content'] == 'Hello, !'
+    assert config.config.prompt is not None
+    assert len(config.config.prompt) > 0
+    assert config.config.prompt[0].content == 'Hello, !'
     assert config.enabled is True
-    assert config.tracker.version_key == 'abcd'
 
 def test_context_interpolation(ldai_client: LDAIClient):
     context = Context.builder('user-key').name("Sandy").build()
-    default_value = AIConfig(config={}, tracker=None, enabled=True)
+    default_value = AIConfig(config=AIConfigData(model={}, prompt=[], _ldMeta={'enabled': True, 'versionKey': 'abcd'}), tracker=LDAIConfigTracker(), enabled=True)
     variables = {'name': 'World'}
 
     config = ldai_client.model_config('ctx-interpolation', context, default_value, variables)
 
-    assert config.config['prompt'][0]['content'] == 'Hello, Sandy!'
+    assert config.config.prompt is not None
+    assert len(config.config.prompt) > 0
+    assert config.config.prompt[0].content == 'Hello, Sandy!'
    assert config.enabled is True
-    assert config.tracker.version_key == 'abcd'
-
+
 def test_model_config_disabled(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(config={}, tracker=None, enabled=True)
+    default_value = AIConfig(config=AIConfigData(model={}, prompt=[], _ldMeta={'enabled': True, 'versionKey': 'abcd'}), tracker=LDAIConfigTracker(), enabled=True)
 
     config = ldai_client.model_config('off-config', context, default_value, {})
 
     assert config.enabled is False
-    assert config.tracker.version_key == 'abcd'
 
 def test_model_config_multiple(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(config={}, tracker=None, enabled=True)
+    default_value = AIConfig(config=AIConfigData(model={}, prompt=[], _ldMeta={'enabled': True, 'versionKey': 'abcd'}), tracker=LDAIConfigTracker(), enabled=True)
     variables = {'name': 'World', 'day': 'Monday'}
 
     config = ldai_client.model_config('multiple-prompt', context, default_value, variables)
 
-    assert config.config['prompt'][0]['content'] == 'Hello, World!'
-    assert config.config['prompt'][1]['content'] == 'The day is, Monday!'
-    assert config.enabled is True
-    assert config.tracker.version_key == 'abcd'
+    assert config.config.prompt is not None
+    assert len(config.config.prompt) > 0
+    assert config.config.prompt[0].content == 'Hello, World!'
+    assert config.config.prompt[1].content == 'The day is, Monday!'
+    assert config.enabled is True
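
For reference, a minimal sketch (not part of the commit) of the typed default these tests now build from AIConfigData and LDMessage; tracker=None is used only to keep the sketch standalone, whereas the tests above call LDAIConfigTracker() with no arguments:

    # Sketch only: a typed default_value mirroring test_model_config_interpolation.
    from ldai.types import AIConfig, AIConfigData, LDMessage

    default_value = AIConfig(
        config=AIConfigData(
            model={'modelId': 'fakeModel'},
            prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
            _ldMeta={'enabled': True, 'versionKey': 'abcd'},
        ),
        tracker=None,  # the tests pass LDAIConfigTracker(); None avoids constructing a tracker here
        enabled=True,
    )

    assert default_value.config.prompt is not None
    assert default_value.config.prompt[0].content == 'Hello, {{name}}!'  # raw template, not yet interpolated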

ldai/tracker.py

Lines changed: 3 additions & 3 deletions
@@ -58,8 +58,8 @@ def track_bedrock_converse(self, res: dict) -> dict:
     def track_tokens(self, tokens: Union[TokenUsage, BedrockTokenUsage]) -> None:
         token_metrics = tokens.to_metrics()
         if token_metrics['total'] > 0:
-            self.ld_client.track('$ld:ai:tokens:total', self.context, self.get_track_data(), token_metrics['total'])
+            self.ld_client.track('$ld:ai:tokens:total', self.context, self.get_track_data(), token_metrics.total)
         if token_metrics['input'] > 0:
-            self.ld_client.track('$ld:ai:tokens:input', self.context, self.get_track_data(), token_metrics['input'])
+            self.ld_client.track('$ld:ai:tokens:input', self.context, self.get_track_data(), token_metrics.input)
         if token_metrics['output'] > 0:
-            self.ld_client.track('$ld:ai:tokens:output', self.context, self.get_track_data(), token_metrics['output'])
+            self.ld_client.track('$ld:ai:tokens:output', self.context, self.get_track_data(), token_metrics.output)
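
Worth noting: the unchanged guard lines above still subscript token_metrics (token_metrics['total'] > 0), while after this commit OpenAITokenUsage.to_metrics() and BedrockTokenUsage.to_metrics() return a TokenMetrics dataclass, so the subscripted guards only work if TokenMetrics supports item access. A fully attribute-based version of the method might look like this sketch (an assumption, not part of the commit):

    # Sketch only: consistent attribute access on the TokenMetrics dataclass.
    def track_tokens(self, tokens) -> None:
        token_metrics = tokens.to_metrics()
        if token_metrics.total > 0:
            self.ld_client.track('$ld:ai:tokens:total', self.context, self.get_track_data(), token_metrics.total)
        if token_metrics.input > 0:
            self.ld_client.track('$ld:ai:tokens:input', self.context, self.get_track_data(), token_metrics.input)
        if token_metrics.output > 0:
            self.ld_client.track('$ld:ai:tokens:output', self.context, self.get_track_data(), token_metrics.output)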

ldai/types.py

Lines changed: 23 additions & 22 deletions
@@ -1,7 +1,9 @@
 from enum import Enum
-from typing import Callable, List, Literal
+from typing import Any, Callable, List, Literal, Optional
 from dataclasses import dataclass
 
+from ldai.tracker import LDAIConfigTracker
+
 @dataclass
 class TokenMetrics():
     total: int
@@ -15,19 +17,12 @@ class LDMessage():
 
 @dataclass
 class AIConfigData():
-    model: dict
-    prompt: List[LDMessage]
+    model: Optional[dict]
+    prompt: Optional[List[LDMessage]]
     _ldMeta: dict
 
-class AITracker():
-    track_duration: Callable[..., None]
-    track_tokens: Callable[..., None]
-    track_error: Callable[..., None]
-    track_generation: Callable[..., None]
-    track_feedback: Callable[..., None]
-
 class AIConfig():
-    def __init__(self, config: AIConfigData, tracker: AITracker, enabled: bool):
+    def __init__(self, config: AIConfigData, tracker: LDAIConfigTracker, enabled: bool):
         self.config = config
         self.tracker = tracker
         self.enabled = enabled
@@ -50,19 +45,25 @@ def to_metrics(self):
             'output': self['completion_tokens'],
         }
 
+@dataclass
+class LDOpenAIUsage():
+    total_tokens: int
+    prompt_tokens: int
+    completion_tokens: int
+
 @dataclass
 class OpenAITokenUsage:
-    def __init__(self, data: any):
+    def __init__(self, data: LDOpenAIUsage):
        self.total_tokens = data.total_tokens
         self.prompt_tokens = data.prompt_tokens
         self.completion_tokens = data.completion_tokens
 
     def to_metrics(self) -> TokenMetrics:
-        return {
-            'total': self.total_tokens,
-            'input': self.prompt_tokens,
-            'output': self.completion_tokens,
-        }
+        return TokenMetrics(
+            total=self.total_tokens,
+            input=self.prompt_tokens,
+            output=self.completion_tokens,
+        )
 
 @dataclass
 class BedrockTokenUsage:
@@ -72,8 +73,8 @@ def __init__(self, data: dict):
         self.outputTokens = data.get('outputTokens', 0)
 
     def to_metrics(self) -> TokenMetrics:
-        return {
-            'total': self.totalTokens,
-            'input': self.inputTokens,
-            'output': self.outputTokens,
-        }
+        return TokenMetrics(
+            total=self.totalTokens,
+            input=self.inputTokens,
+            output=self.outputTokens,
+        )
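
For reference, a minimal usage sketch (not part of the commit) of the new typed path from LDOpenAIUsage through to_metrics(); the field names and import path are taken from the diff above:

    # Sketch only: building usage from the new dataclass and converting to TokenMetrics.
    from ldai.types import LDOpenAIUsage, OpenAITokenUsage

    usage = LDOpenAIUsage(total_tokens=100, prompt_tokens=60, completion_tokens=40)
    metrics = OpenAITokenUsage(usage).to_metrics()

    assert metrics.total == 100   # TokenMetrics fields are now accessed as attributes
    assert metrics.input == 60
    assert metrics.output == 40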

0 commit comments
