Skip to content

Commit af8f3d6

Browse files
committed
fix: Distinguish AIConfig from DefaultAIConfig
Consumers of this SDK are responsible for providing a default value to the `config` method. If that value is of type `AIConfig`, customers are able to provide a tracker instance that will never be used. To help discourage this, we are introducing two separate types -- one to represent the default AI config, and one to represent the actual selected config.
1 parent 8e2d409 commit af8f3d6

File tree

4 files changed

+99
-36
lines changed

4 files changed

+99
-36
lines changed

ldai/client.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,26 @@
22
from typing import Any, Dict, List, Literal, Optional
33

44
import chevron
5-
from dataclasses_json import dataclass_json
65
from ldclient import Context
76
from ldclient.client import LDClient
87

98
from ldai.tracker import LDAIConfigTracker
109

1110

12-
@dataclass_json
1311
@dataclass
1412
class LDMessage:
1513
role: Literal['system', 'user', 'assistant']
1614
content: str
1715

16+
def to_dict(self) -> dict:
17+
"""
18+
Render the given message as a dictionary object.
19+
"""
20+
return {
21+
'role': self.role,
22+
'content': self.content,
23+
}
24+
1825

1926
class ModelConfig:
2027
"""
@@ -62,6 +69,16 @@ def get_custom(self, key: str) -> Any:
6269

6370
return self._custom.get(key)
6471

72+
def to_dict(self) -> dict:
73+
"""
74+
Render the given model config as a dictionary object.
75+
"""
76+
return {
77+
'id': self._id,
78+
'parameters': self._parameters,
79+
'custom': self._custom,
80+
}
81+
6582

6683
class ProviderConfig:
6784
"""
@@ -78,14 +95,46 @@ def id(self) -> str:
7895
"""
7996
return self._id
8097

98+
def to_dict(self) -> dict:
99+
"""
100+
Render the given provider config as a dictionary object.
101+
"""
102+
return {
103+
'id': self._id,
104+
}
105+
106+
107+
@dataclass(frozen=True)
108+
class DefaultAIConfig:
109+
"""
110+
The default values when evaluating an AI configuration.
111+
"""
112+
enabled: Optional[bool] = None
113+
model: Optional[ModelConfig] = None
114+
messages: Optional[List[LDMessage]] = None
115+
provider: Optional[ProviderConfig] = None
116+
117+
def to_dict(self) -> dict:
118+
"""
119+
Render the given default values as an AIConfig-compatible dictionary object.
120+
"""
121+
return {
122+
'_ldMeta': {
123+
'enabled': self.enabled or False,
124+
},
125+
'model': self.model.to_dict() if self.model else None,
126+
'messages': [message.to_dict() for message in self.messages] if self.messages else None,
127+
'provider': self.provider.to_dict() if self.provider else None,
128+
}
129+
81130

131+
@dataclass(frozen=True)
82132
class AIConfig:
83-
def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]], provider: Optional[ProviderConfig] = None):
84-
self.tracker = tracker
85-
self.enabled = enabled
86-
self.model = model
87-
self.messages = messages
88-
self.provider = provider
133+
tracker: LDAIConfigTracker
134+
enabled: bool
135+
model: Optional[ModelConfig] = None
136+
messages: Optional[List[LDMessage]] = None
137+
provider: Optional[ProviderConfig] = None
89138

90139

91140
class LDAIClient:
@@ -98,7 +147,7 @@ def config(
98147
self,
99148
key: str,
100149
context: Context,
101-
default_value: AIConfig,
150+
default_value: DefaultAIConfig,
102151
variables: Optional[Dict[str, Any]] = None,
103152
) -> AIConfig:
104153
"""
@@ -110,7 +159,7 @@ def config(
110159
:param variables: Additional variables for the model configuration.
111160
:return: The value of the model configuration.
112161
"""
113-
variation = self.client.variation(key, context, default_value)
162+
variation = self.client.variation(key, context, default_value.to_dict())
114163

115164
all_variables = {}
116165
if variables:

ldai/testing/test_model_config.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
from ldclient import Config, Context, LDClient
33
from ldclient.integrations.test_data import TestData
44

5-
from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig
6-
from ldai.tracker import LDAIConfigTracker
5+
from ldai.client import DefaultAIConfig, LDAIClient, LDMessage, ModelConfig
76

87

98
@pytest.fixture
@@ -98,11 +97,6 @@ def client(td: TestData) -> LDClient:
9897
return LDClient(config=config)
9998

10099

101-
@pytest.fixture
102-
def tracker(client: LDClient) -> LDAIConfigTracker:
103-
return LDAIConfigTracker(client, 'abcd', 'model-config', Context.create('user-key'))
104-
105-
106100
@pytest.fixture
107101
def ldai_client(client: LDClient) -> LDAIClient:
108102
return LDAIClient(client)
@@ -125,10 +119,31 @@ def test_model_config_handles_custom():
125119
assert model.get_custom('id') is None
126120

127121

128-
def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
122+
def test_uses_default_on_invalid_flag(ldai_client: LDAIClient):
123+
context = Context.create('user-key')
124+
default_value = DefaultAIConfig(
125+
enabled=True,
126+
model=ModelConfig('fakeModel', parameters={'temperature': 0.5, 'maxTokens': 4096}),
127+
messages=[LDMessage(role='system', content='Hello, {{name}}!')],
128+
)
129+
variables = {'name': 'World'}
130+
131+
config = ldai_client.config('missing-flag', context, default_value, variables)
132+
133+
assert config.messages is not None
134+
assert len(config.messages) > 0
135+
assert config.messages[0].content == 'Hello, World!'
136+
assert config.enabled is True
137+
138+
assert config.model is not None
139+
assert config.model.id == 'fakeModel'
140+
assert config.model.get_parameter('temperature') == 0.5
141+
assert config.model.get_parameter('maxTokens') == 4096
142+
143+
144+
def test_model_config_interpolation(ldai_client: LDAIClient):
129145
context = Context.create('user-key')
130-
default_value = AIConfig(
131-
tracker=tracker,
146+
default_value = DefaultAIConfig(
132147
enabled=True,
133148
model=ModelConfig('fakeModel'),
134149
messages=[LDMessage(role='system', content='Hello, {{name}}!')],
@@ -148,9 +163,9 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
148163
assert config.model.get_parameter('maxTokens') == 4096
149164

150165

151-
def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
166+
def test_model_config_no_variables(ldai_client: LDAIClient):
152167
context = Context.create('user-key')
153-
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
168+
default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
154169

155170
config = ldai_client.config('model-config', context, default_value, {})
156171

@@ -165,9 +180,9 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
165180
assert config.model.get_parameter('maxTokens') == 4096
166181

167182

168-
def test_provider_config_handling(ldai_client: LDAIClient, tracker):
183+
def test_provider_config_handling(ldai_client: LDAIClient):
169184
context = Context.builder('user-key').name("Sandy").build()
170-
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
185+
default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
171186
variables = {'name': 'World'}
172187

173188
config = ldai_client.config('model-config', context, default_value, variables)
@@ -176,9 +191,9 @@ def test_provider_config_handling(ldai_client: LDAIClient, tracker):
176191
assert config.provider.id == 'fakeProvider'
177192

178193

179-
def test_context_interpolation(ldai_client: LDAIClient, tracker):
194+
def test_context_interpolation(ldai_client: LDAIClient):
180195
context = Context.builder('user-key').name("Sandy").build()
181-
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
196+
default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
182197
variables = {'name': 'World'}
183198

184199
config = ldai_client.config(
@@ -197,9 +212,9 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker):
197212
assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to'
198213

199214

200-
def test_model_config_multiple(ldai_client: LDAIClient, tracker):
215+
def test_model_config_multiple(ldai_client: LDAIClient):
201216
context = Context.create('user-key')
202-
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
217+
default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
203218
variables = {'name': 'World', 'day': 'Monday'}
204219

205220
config = ldai_client.config(
@@ -218,9 +233,9 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker):
218233
assert config.model.get_parameter('maxTokens') == 8192
219234

220235

221-
def test_model_config_disabled(ldai_client: LDAIClient, tracker):
236+
def test_model_config_disabled(ldai_client: LDAIClient):
222237
context = Context.create('user-key')
223-
default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[])
238+
default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
224239

225240
config = ldai_client.config('off-config', context, default_value, {})
226241

@@ -231,9 +246,9 @@ def test_model_config_disabled(ldai_client: LDAIClient, tracker):
231246
assert config.model.get_parameter('maxTokens') is None
232247

233248

234-
def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker):
249+
def test_model_initial_config_disabled(ldai_client: LDAIClient):
235250
context = Context.create('user-key')
236-
default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[])
251+
default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
237252

238253
config = ldai_client.config('initial-config-disabled', context, default_value, {})
239254

@@ -243,9 +258,9 @@ def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker):
243258
assert config.provider is None
244259

245260

246-
def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker):
261+
def test_model_initial_config_enabled(ldai_client: LDAIClient):
247262
context = Context.create('user-key')
248-
default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[])
263+
default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
249264

250265
config = ldai_client.config('initial-config-enabled', context, default_value, {})
251266

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ exclude = [
3030
python = ">=3.8,<4"
3131
launchdarkly-server-sdk = ">=9.4.0"
3232
chevron = "=0.14.0"
33-
dataclasses-json = "^0.6.7"
3433

3534

3635
[tool.poetry.group.dev.dependencies]

release-please-config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
".": {
44
"release-type": "python",
55
"versioning": "default",
6-
"release-as": "0.3.0",
6+
"bump-minor-pre-major": true,
77
"include-v-in-tag": false,
88
"extra-files": ["ldai/__init__.py", "PROVENANCE.md"],
99
"include-component-in-tag": false

0 commit comments

Comments
 (0)