Skip to content

Commit 96f888f

Browse files
authored
fix: Return AIConfig and LDAITracker separately (#23)
Consumers of this SDK are responsible for providing some default value to the `config` method. If that type is `AIConfig`, consumers have the option of providing a tracker instance that will never be useful. To address this, we are going to remove the tracker from the `AIConfig` type entirely. Instead, we are going to return the `AIConfig` and the `LDAITracker` separately.
1 parent 8e2d409 commit 96f888f

File tree

4 files changed

+105
-51
lines changed

4 files changed

+105
-51
lines changed

ldai/client.py

Lines changed: 59 additions & 19 deletions
Original file line number · Diff line number · Diff line change
@@ -1,20 +1,27 @@
11
from dataclasses import dataclass
2-
from typing import Any, Dict, List, Literal, Optional
2+
from typing import Any, Dict, List, Literal, Optional, Tuple
33

44
import chevron
5-
from dataclasses_json import dataclass_json
65
from ldclient import Context
76
from ldclient.client import LDClient
87

98
from ldai.tracker import LDAIConfigTracker
109

1110

12-
@dataclass_json
1311
@dataclass
1412
class LDMessage:
1513
role: Literal['system', 'user', 'assistant']
1614
content: str
1715

16+
def to_dict(self) -> dict:
17+
"""
18+
Render the given message as a dictionary object.
19+
"""
20+
return {
21+
'role': self.role,
22+
'content': self.content,
23+
}
24+
1825

1926
class ModelConfig:
2027
"""
@@ -62,6 +69,16 @@ def get_custom(self, key: str) -> Any:
6269

6370
return self._custom.get(key)
6471

72+
def to_dict(self) -> dict:
73+
"""
74+
Render the given model config as a dictionary object.
75+
"""
76+
return {
77+
'id': self._id,
78+
'parameters': self._parameters,
79+
'custom': self._custom,
80+
}
81+
6582

6683
class ProviderConfig:
6784
"""
@@ -78,14 +95,34 @@ def id(self) -> str:
7895
"""
7996
return self._id
8097

98+
def to_dict(self) -> dict:
99+
"""
100+
Render the given provider config as a dictionary object.
101+
"""
102+
return {
103+
'id': self._id,
104+
}
81105

106+
107+
@dataclass(frozen=True)
82108
class AIConfig:
83-
def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]], provider: Optional[ProviderConfig] = None):
84-
self.tracker = tracker
85-
self.enabled = enabled
86-
self.model = model
87-
self.messages = messages
88-
self.provider = provider
109+
enabled: Optional[bool] = None
110+
model: Optional[ModelConfig] = None
111+
messages: Optional[List[LDMessage]] = None
112+
provider: Optional[ProviderConfig] = None
113+
114+
def to_dict(self) -> dict:
115+
"""
116+
Render the given default values as an AIConfig-compatible dictionary object.
117+
"""
118+
return {
119+
'_ldMeta': {
120+
'enabled': self.enabled or False,
121+
},
122+
'model': self.model.to_dict() if self.model else None,
123+
'messages': [message.to_dict() for message in self.messages] if self.messages else None,
124+
'provider': self.provider.to_dict() if self.provider else None,
125+
}
89126

90127

91128
class LDAIClient:
@@ -100,17 +137,17 @@ def config(
100137
context: Context,
101138
default_value: AIConfig,
102139
variables: Optional[Dict[str, Any]] = None,
103-
) -> AIConfig:
140+
) -> Tuple[AIConfig, LDAIConfigTracker]:
104141
"""
105142
Get the value of a model configuration.
106143
107144
:param key: The key of the model configuration.
108145
:param context: The context to evaluate the model configuration in.
109146
:param default_value: The default value of the model configuration.
110147
:param variables: Additional variables for the model configuration.
111-
:return: The value of the model configuration.
148+
:return: The value of the model configuration along with a tracker used for gathering metrics.
112149
"""
113-
variation = self.client.variation(key, context, default_value)
150+
variation = self.client.variation(key, context, default_value.to_dict())
114151

115152
all_variables = {}
116153
if variables:
@@ -146,20 +183,23 @@ def config(
146183
custom=custom
147184
)
148185

186+
tracker = LDAIConfigTracker(
187+
self.client,
188+
variation.get('_ldMeta', {}).get('versionKey', ''),
189+
key,
190+
context,
191+
)
192+
149193
enabled = variation.get('_ldMeta', {}).get('enabled', False)
150-
return AIConfig(
151-
tracker=LDAIConfigTracker(
152-
self.client,
153-
variation.get('_ldMeta', {}).get('versionKey', ''),
154-
key,
155-
context,
156-
),
194+
config = AIConfig(
157195
enabled=bool(enabled),
158196
model=model,
159197
messages=messages,
160198
provider=provider_config,
161199
)
162200

201+
return config, tracker
202+
163203
def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:
164204
"""
165205
Interpolate the template with the given variables.

ldai/testing/test_model_config.py

Lines changed: 45 additions & 30 deletions
Original file line number · Diff line number · Diff line change
@@ -3,7 +3,6 @@
33
from ldclient.integrations.test_data import TestData
44

55
from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig
6-
from ldai.tracker import LDAIConfigTracker
76

87

98
@pytest.fixture
@@ -98,11 +97,6 @@ def client(td: TestData) -> LDClient:
9897
return LDClient(config=config)
9998

10099

101-
@pytest.fixture
102-
def tracker(client: LDClient) -> LDAIConfigTracker:
103-
return LDAIConfigTracker(client, 'abcd', 'model-config', Context.create('user-key'))
104-
105-
106100
@pytest.fixture
107101
def ldai_client(client: LDClient) -> LDAIClient:
108102
return LDAIClient(client)
@@ -125,17 +119,38 @@ def test_model_config_handles_custom():
125119
assert model.get_custom('id') is None
126120

127121

128-
def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
122+
def test_uses_default_on_invalid_flag(ldai_client: LDAIClient):
123+
context = Context.create('user-key')
124+
default_value = AIConfig(
125+
enabled=True,
126+
model=ModelConfig('fakeModel', parameters={'temperature': 0.5, 'maxTokens': 4096}),
127+
messages=[LDMessage(role='system', content='Hello, {{name}}!')],
128+
)
129+
variables = {'name': 'World'}
130+
131+
config, _ = ldai_client.config('missing-flag', context, default_value, variables)
132+
133+
assert config.messages is not None
134+
assert len(config.messages) > 0
135+
assert config.messages[0].content == 'Hello, World!'
136+
assert config.enabled is True
137+
138+
assert config.model is not None
139+
assert config.model.id == 'fakeModel'
140+
assert config.model.get_parameter('temperature') == 0.5
141+
assert config.model.get_parameter('maxTokens') == 4096
142+
143+
144+
def test_model_config_interpolation(ldai_client: LDAIClient):
129145
context = Context.create('user-key')
130146
default_value = AIConfig(
131-
tracker=tracker,
132147
enabled=True,
133148
model=ModelConfig('fakeModel'),
134149
messages=[LDMessage(role='system', content='Hello, {{name}}!')],
135150
)
136151
variables = {'name': 'World'}
137152

138-
config = ldai_client.config('model-config', context, default_value, variables)
153+
config, _ = ldai_client.config('model-config', context, default_value, variables)
139154

140155
assert config.messages is not None
141156
assert len(config.messages) > 0
@@ -148,11 +163,11 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
148163
assert config.model.get_parameter('maxTokens') == 4096
149164

150165

151-
def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
166+
def test_model_config_no_variables(ldai_client: LDAIClient):
152167
context = Context.create('user-key')
153-
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
168+
default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
154169

155-
config = ldai_client.config('model-config', context, default_value, {})
170+
config, _ = ldai_client.config('model-config', context, default_value, {})
156171

157172
assert config.messages is not None
158173
assert len(config.messages) > 0
@@ -165,23 +180,23 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
165180
assert config.model.get_parameter('maxTokens') == 4096
166181

167182

168-
def test_provider_config_handling(ldai_client: LDAIClient, tracker):
183+
def test_provider_config_handling(ldai_client: LDAIClient):
169184
context = Context.builder('user-key').name("Sandy").build()
170-
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
185+
default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
171186
variables = {'name': 'World'}
172187

173-
config = ldai_client.config('model-config', context, default_value, variables)
188+
config, _ = ldai_client.config('model-config', context, default_value, variables)
174189

175190
assert config.provider is not None
176191
assert config.provider.id == 'fakeProvider'
177192

178193

179-
def test_context_interpolation(ldai_client: LDAIClient, tracker):
194+
def test_context_interpolation(ldai_client: LDAIClient):
180195
context = Context.builder('user-key').name("Sandy").build()
181-
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
196+
default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
182197
variables = {'name': 'World'}
183198

184-
config = ldai_client.config(
199+
config, _ = ldai_client.config(
185200
'ctx-interpolation', context, default_value, variables
186201
)
187202

@@ -197,12 +212,12 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker):
197212
assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to'
198213

199214

200-
def test_model_config_multiple(ldai_client: LDAIClient, tracker):
215+
def test_model_config_multiple(ldai_client: LDAIClient):
201216
context = Context.create('user-key')
202-
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
217+
default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
203218
variables = {'name': 'World', 'day': 'Monday'}
204219

205-
config = ldai_client.config(
220+
config, _ = ldai_client.config(
206221
'multiple-messages', context, default_value, variables
207222
)
208223

@@ -218,11 +233,11 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker):
218233
assert config.model.get_parameter('maxTokens') == 8192
219234

220235

221-
def test_model_config_disabled(ldai_client: LDAIClient, tracker):
236+
def test_model_config_disabled(ldai_client: LDAIClient):
222237
context = Context.create('user-key')
223-
default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[])
238+
default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
224239

225-
config = ldai_client.config('off-config', context, default_value, {})
240+
config, _ = ldai_client.config('off-config', context, default_value, {})
226241

227242
assert config.model is not None
228243
assert config.enabled is False
@@ -231,23 +246,23 @@ def test_model_config_disabled(ldai_client: LDAIClient, tracker):
231246
assert config.model.get_parameter('maxTokens') is None
232247

233248

234-
def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker):
249+
def test_model_initial_config_disabled(ldai_client: LDAIClient):
235250
context = Context.create('user-key')
236-
default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[])
251+
default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
237252

238-
config = ldai_client.config('initial-config-disabled', context, default_value, {})
253+
config, _ = ldai_client.config('initial-config-disabled', context, default_value, {})
239254

240255
assert config.enabled is False
241256
assert config.model is None
242257
assert config.messages is None
243258
assert config.provider is None
244259

245260

246-
def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker):
261+
def test_model_initial_config_enabled(ldai_client: LDAIClient):
247262
context = Context.create('user-key')
248-
default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[])
263+
default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
249264

250-
config = ldai_client.config('initial-config-enabled', context, default_value, {})
265+
config, _ = ldai_client.config('initial-config-enabled', context, default_value, {})
251266

252267
assert config.enabled is True
253268
assert config.model is None

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -30,7 +30,6 @@ exclude = [
3030
python = ">=3.8,<4"
3131
launchdarkly-server-sdk = ">=9.4.0"
3232
chevron = "=0.14.0"
33-
dataclasses-json = "^0.6.7"
3433

3534

3635
[tool.poetry.group.dev.dependencies]

release-please-config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
".": {
44
"release-type": "python",
55
"versioning": "default",
6-
"release-as": "0.3.0",
6+
"bump-minor-pre-major": true,
77
"include-v-in-tag": false,
88
"extra-files": ["ldai/__init__.py", "PROVENANCE.md"],
99
"include-component-in-tag": false

0 commit comments

Comments (0)