Commit 6b4f8c1

Expand tests and assertions for new model config
1 parent dcbd88f commit 6b4f8c1

2 files changed: +112 -13 lines changed

ldai/client.py

Lines changed: 68 additions & 2 deletions
@@ -16,8 +16,65 @@ class LDMessage:
     content: str
 
 
+class ModelConfig:
+    """
+    Configuration related to the model.
+    """
+
+    def __init__(self, id: str, temperature: Optional[float] = None,
+                 max_tokens: Optional[int] = None, attributes: dict = {}):
+        """
+        :param id: The ID of the model.
+        :param temperature: Tuning parameter for randomness versus determinism. Exact effect will be determined by the model.
+        :param max_tokens: The maximum number of tokens.
+        :param attributes: Additional model-specific attributes.
+        """
+        self._id = id
+        self._temperature = temperature
+        self._max_tokens = max_tokens
+        self._attributes = attributes
+
+    @property
+    def id(self):
+        """
+        The ID of the model.
+        """
+        return self._id
+
+    @property
+    def temperature(self):
+        """
+        Tuning parameter for randomness versus determinism. Exact effect will be determined by the model.
+        """
+        return self._temperature
+
+    @property
+    def max_tokens(self):
+        """
+        The maximum number of tokens.
+        """
+
+        return self._max_tokens
+
+    def get_attribute(self, key: str):
+        """
+        Retrieve model-specific attributes.
+
+        Accessing a named, typed attribute (e.g. id) will result in the call
+        being delegated to the appropriate property.
+        """
+        if key == 'id':
+            return self.id
+        if key == 'temperature':
+            return self.temperature
+        if key == 'maxTokens':
+            return self.max_tokens
+
+        return self._attributes.get(key)
+
+
 class AIConfig:
-    def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[dict], prompt: Optional[List[LDMessage]]):
+    def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], prompt: Optional[List[LDMessage]]):
         self.tracker = tracker
         self.enabled = enabled
         self.model = model
@@ -66,6 +123,15 @@ def model_config(
             for entry in variation['prompt']
         ]
 
+        model = None
+        if 'model' in variation:
+            model = ModelConfig(
+                id=variation['model']['modelId'],
+                temperature=variation['model'].get('temperature'),
+                max_tokens=variation['model'].get('maxTokens'),
+                attributes=variation['model'],
+            )
+
         enabled = variation.get('_ldMeta', {}).get('enabled', False)
         return AIConfig(
             tracker=LDAIConfigTracker(
@@ -75,7 +141,7 @@ def model_config(
                 context,
             ),
             enabled=bool(enabled),
-            model=variation['model'],
+            model=model,
             prompt=prompt
         )
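For orientation, a minimal sketch of how the new ModelConfig behaves on its own. The model ID, attribute names, and values below are illustrative and not taken from the commit; only the class and its id/temperature/max_tokens/get_attribute members come from the diff above.

from ldai.client import ModelConfig

# Construct a config the same way model_config() does from a variation dict.
model = ModelConfig(
    id='example-model',                 # illustrative model ID
    temperature=0.2,
    max_tokens=1024,
    attributes={'provider': 'example'}  # arbitrary model-specific extras
)

# Named, typed attributes are exposed as properties...
assert model.id == 'example-model'
assert model.temperature == 0.2

# ...and get_attribute() delegates to those properties for 'id',
# 'temperature', and 'maxTokens', falling back to the attributes dict
# for anything else (None when the key is absent).
assert model.get_attribute('maxTokens') == 1024
assert model.get_attribute('provider') == 'example'
assert model.get_attribute('unknown-key') is None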

ldai/testing/test_model_config.py

Lines changed: 44 additions & 11 deletions
@@ -1,9 +1,8 @@
 import pytest
 from ldclient import Config, Context, LDClient
 from ldclient.integrations.test_data import TestData
-from ldclient.testing.builders import *
 
-from ldai.client import AIConfig, LDAIClient, LDMessage
+from ldai.client import AIConfig, ModelConfig, LDAIClient, LDMessage
 from ldai.tracker import LDAIConfigTracker
 
 
@@ -14,7 +13,7 @@ def td() -> TestData:
         td.flag('model-config')
         .variations(
             {
-                'model': {'modelId': 'fakeModel'},
+                'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096},
                 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
                 '_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
             },
@@ -27,7 +26,7 @@ def td() -> TestData:
         td.flag('multiple-prompt')
         .variations(
             {
-                'model': {'modelId': 'fakeModel'},
+                'model': {'modelId': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192},
                 'prompt': [
                     {'role': 'system', 'content': 'Hello, {{name}}!'},
                     {'role': 'user', 'content': 'The day is, {{day}}!'},
@@ -43,7 +42,7 @@ def td() -> TestData:
         td.flag('ctx-interpolation')
         .variations(
             {
-                'model': {'modelId': 'fakeModel'},
+                'model': {'modelId': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'},
                 'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}],
                 '_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
             }
@@ -55,7 +54,7 @@ def td() -> TestData:
         td.flag('off-config')
         .variations(
             {
-                'model': {'modelId': 'fakeModel'},
+                'model': {'modelId': 'fakeModel', 'temperature': 0.1},
                 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
                 '_ldMeta': {'enabled': False, 'versionKey': 'abcd'},
             }
@@ -82,12 +81,26 @@ def ldai_client(client: LDClient) -> LDAIClient:
     return LDAIClient(client)
 
 
+def test_model_config_delegates_to_properties():
+    model = ModelConfig('fakeModel', temperature=0.5, max_tokens=4096, attributes={'extra-attribute': 'value'})
+    assert model.id == 'fakeModel'
+    assert model.temperature == 0.5
+    assert model.max_tokens == 4096
+    assert model.get_attribute('extra-attribute') == 'value'
+    assert model.get_attribute('non-existent') is None
+
+    assert model.id == model.get_attribute('id')
+    assert model.temperature == model.get_attribute('temperature')
+    assert model.max_tokens == model.get_attribute('maxTokens')
+    assert model.max_tokens != model.get_attribute('max_tokens')
+
+
 def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
     context = Context.create('user-key')
     default_value = AIConfig(
         tracker=tracker,
         enabled=True,
-        model={'modelId': 'fakeModel'},
+        model=ModelConfig('fakeModel'),
         prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
     )
     variables = {'name': 'World'}
@@ -99,10 +112,14 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
     assert config.prompt[0].content == 'Hello, World!'
     assert config.enabled is True
 
+    assert config.model.id == 'fakeModel'
+    assert config.model.temperature == 0.5
+    assert config.model.max_tokens == 4096
+
 
 def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
     context = Context.create('user-key')
-    default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[])
+    default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])
 
     config = ldai_client.model_config('model-config', context, default_value, {})
 
@@ -111,10 +128,14 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
     assert config.prompt[0].content == 'Hello, !'
     assert config.enabled is True
 
+    assert config.model.id == 'fakeModel'
+    assert config.model.temperature == 0.5
+    assert config.model.max_tokens == 4096
+
 
 def test_context_interpolation(ldai_client: LDAIClient, tracker):
     context = Context.builder('user-key').name("Sandy").build()
-    default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[])
+    default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])
     variables = {'name': 'World'}
 
     config = ldai_client.model_config(
@@ -126,10 +147,15 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker):
     assert config.prompt[0].content == 'Hello, Sandy!'
     assert config.enabled is True
 
+    assert config.model.id == 'fakeModel'
+    assert config.model.temperature is None
+    assert config.model.max_tokens is None
+    assert config.model.get_attribute('extra-attribute') == 'I can be anything I set my mind/type to'
+
 
 def test_model_config_multiple(ldai_client: LDAIClient, tracker):
     context = Context.create('user-key')
-    default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[])
+    default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[])
     variables = {'name': 'World', 'day': 'Monday'}
 
     config = ldai_client.model_config(
@@ -142,11 +168,18 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker):
     assert config.prompt[1].content == 'The day is, Monday!'
     assert config.enabled is True
 
+    assert config.model.id == 'fakeModel'
+    assert config.model.temperature == 0.7
+    assert config.model.max_tokens == 8192
+
 
 def test_model_config_disabled(ldai_client: LDAIClient, tracker):
     context = Context.create('user-key')
-    default_value = AIConfig(tracker=tracker, enabled=False, model={}, prompt=[])
+    default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[])
 
     config = ldai_client.model_config('off-config', context, default_value, {})
 
     assert config.enabled is False
+    assert config.model.id == 'fakeModel'
+    assert config.model.temperature == 0.1
+    assert config.model.max_tokens is None
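Taken together, the tests above imply the consumption pattern for application code: model settings are now read through ModelConfig properties instead of a raw dict. A sketch reusing the ldai_client, tracker, and 'model-config' flag key from these fixtures; the fallback model ID and the caller-side default values are illustrative, not part of the SDK.

context = Context.create('user-key')
default_value = AIConfig(
    tracker=tracker,
    enabled=False,
    model=ModelConfig('fallback-model'),  # illustrative fallback model ID
    prompt=[],
)

config = ldai_client.model_config('model-config', context, default_value, {'name': 'World'})

if config.enabled and config.model is not None:
    # Optional settings may be None, so apply caller-side defaults (0.7 and 256 are illustrative).
    temperature = config.model.temperature if config.model.temperature is not None else 0.7
    max_tokens = config.model.max_tokens if config.model.max_tokens is not None else 256
    # Provider-specific extras remain reachable by key.
    extra = config.model.get_attribute('extra-attribute')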
