Commit b2eb9ab

Return config and tracker separately

1 parent af8f3d6 commit b2eb9ab

File tree

2 files changed: +34 additions, -43 deletions

ldai/client.py

Lines changed: 15 additions & 24 deletions

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Tuple

 import chevron
 from ldclient import Context
@@ -105,10 +105,7 @@ def to_dict(self) -> dict:


 @dataclass(frozen=True)
-class DefaultAIConfig:
-    """
-    The default values when evaluating an AI configuration.
-    """
+class AIConfig:
     enabled: Optional[bool] = None
     model: Optional[ModelConfig] = None
     messages: Optional[List[LDMessage]] = None
@@ -128,15 +125,6 @@ def to_dict(self) -> dict:
         }


-@dataclass(frozen=True)
-class AIConfig:
-    tracker: LDAIConfigTracker
-    enabled: bool
-    model: Optional[ModelConfig] = None
-    messages: Optional[List[LDMessage]] = None
-    provider: Optional[ProviderConfig] = None
-
-
 class LDAIClient:
     """The LaunchDarkly AI SDK client object."""

@@ -147,17 +135,17 @@ def config(
         self,
         key: str,
         context: Context,
-        default_value: DefaultAIConfig,
+        default_value: AIConfig,
         variables: Optional[Dict[str, Any]] = None,
-    ) -> AIConfig:
+    ) -> Tuple[AIConfig, LDAIConfigTracker]:
         """
         Get the value of a model configuration.

         :param key: The key of the model configuration.
         :param context: The context to evaluate the model configuration in.
         :param default_value: The default value of the model configuration.
         :param variables: Additional variables for the model configuration.
-        :return: The value of the model configuration.
+        :return: The value of the model configuration along with a tracker used for gathering metrics.
         """
         variation = self.client.variation(key, context, default_value.to_dict())

@@ -195,20 +183,23 @@ def config(
             custom=custom
         )

+        tracker = LDAIConfigTracker(
+            self.client,
+            variation.get('_ldMeta', {}).get('versionKey', ''),
+            key,
+            context,
+        )
+
         enabled = variation.get('_ldMeta', {}).get('enabled', False)
-        return AIConfig(
-            tracker=LDAIConfigTracker(
-                self.client,
-                variation.get('_ldMeta', {}).get('versionKey', ''),
-                key,
-                context,
-            ),
+        config = AIConfig(
             enabled=bool(enabled),
             model=model,
             messages=messages,
             provider=provider_config,
         )

+        return config, tracker
+
     def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:
         """
         Interpolate the template with the given variables.
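
After this change, the evaluated configuration and the metrics tracker come back as a tuple rather than as a single object carrying a tracker attribute. A minimal usage sketch under assumptions: the SDK key and the 'my-ai-config' flag key below are placeholders, and LDAIClient is assumed to wrap an initialized LDClient (mirroring the test fixture in this commit):

from ldclient import Config, Context, LDClient
from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig

# Placeholder SDK key; a real application supplies its own.
ld_client = LDClient(Config('sdk-key'))
ai_client = LDAIClient(ld_client)  # assumed construction, as in the test fixture

context = Context.create('user-key')
default_value = AIConfig(
    enabled=True,
    model=ModelConfig('my-default-model'),
    messages=[LDMessage(role='system', content='Hello, {{name}}!')],
)

# config() now returns (AIConfig, LDAIConfigTracker) instead of an AIConfig
# that carries its own tracker.
config, tracker = ai_client.config('my-ai-config', context, default_value, {'name': 'World'})

if config.enabled:
    # Use config.model / config.messages to call the model provider, and use
    # the returned tracker to record metrics for this evaluation.
    pass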

ldai/testing/test_model_config.py

Lines changed: 19 additions & 19 deletions

@@ -2,7 +2,7 @@
 from ldclient import Config, Context, LDClient
 from ldclient.integrations.test_data import TestData

-from ldai.client import DefaultAIConfig, LDAIClient, LDMessage, ModelConfig
+from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig


 @pytest.fixture
@@ -121,14 +121,14 @@ def test_model_config_handles_custom():

 def test_uses_default_on_invalid_flag(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = DefaultAIConfig(
+    default_value = AIConfig(
         enabled=True,
         model=ModelConfig('fakeModel', parameters={'temperature': 0.5, 'maxTokens': 4096}),
         messages=[LDMessage(role='system', content='Hello, {{name}}!')],
     )
     variables = {'name': 'World'}

-    config = ldai_client.config('missing-flag', context, default_value, variables)
+    config, _ = ldai_client.config('missing-flag', context, default_value, variables)

     assert config.messages is not None
     assert len(config.messages) > 0
@@ -143,14 +143,14 @@ def test_uses_default_on_invalid_flag(ldai_client: LDAIClient):

 def test_model_config_interpolation(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = DefaultAIConfig(
+    default_value = AIConfig(
         enabled=True,
         model=ModelConfig('fakeModel'),
         messages=[LDMessage(role='system', content='Hello, {{name}}!')],
     )
     variables = {'name': 'World'}

-    config = ldai_client.config('model-config', context, default_value, variables)
+    config, _ = ldai_client.config('model-config', context, default_value, variables)

     assert config.messages is not None
     assert len(config.messages) > 0
@@ -165,9 +165,9 @@ def test_model_config_interpolation(ldai_client: LDAIClient):

 def test_model_config_no_variables(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
+    default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])

-    config = ldai_client.config('model-config', context, default_value, {})
+    config, _ = ldai_client.config('model-config', context, default_value, {})

     assert config.messages is not None
     assert len(config.messages) > 0
@@ -182,21 +182,21 @@ def test_model_config_no_variables(ldai_client: LDAIClient):

 def test_provider_config_handling(ldai_client: LDAIClient):
     context = Context.builder('user-key').name("Sandy").build()
-    default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
+    default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
     variables = {'name': 'World'}

-    config = ldai_client.config('model-config', context, default_value, variables)
+    config, _ = ldai_client.config('model-config', context, default_value, variables)

     assert config.provider is not None
     assert config.provider.id == 'fakeProvider'


 def test_context_interpolation(ldai_client: LDAIClient):
     context = Context.builder('user-key').name("Sandy").build()
-    default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
+    default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
     variables = {'name': 'World'}

-    config = ldai_client.config(
+    config, _ = ldai_client.config(
         'ctx-interpolation', context, default_value, variables
     )

@@ -214,10 +214,10 @@ def test_context_interpolation(ldai_client: LDAIClient):

 def test_model_config_multiple(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = DefaultAIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
+    default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
     variables = {'name': 'World', 'day': 'Monday'}

-    config = ldai_client.config(
+    config, _ = ldai_client.config(
         'multiple-messages', context, default_value, variables
     )

@@ -235,9 +235,9 @@ def test_model_config_multiple(ldai_client: LDAIClient):

 def test_model_config_disabled(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
+    default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])

-    config = ldai_client.config('off-config', context, default_value, {})
+    config, _ = ldai_client.config('off-config', context, default_value, {})

     assert config.model is not None
     assert config.enabled is False
@@ -248,9 +248,9 @@ def test_model_config_disabled(ldai_client: LDAIClient):

 def test_model_initial_config_disabled(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
+    default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])

-    config = ldai_client.config('initial-config-disabled', context, default_value, {})
+    config, _ = ldai_client.config('initial-config-disabled', context, default_value, {})

     assert config.enabled is False
     assert config.model is None
@@ -260,9 +260,9 @@ def test_model_initial_config_disabled(ldai_client: LDAIClient):

 def test_model_initial_config_enabled(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = DefaultAIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
+    default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])

-    config = ldai_client.config('initial-config-enabled', context, default_value, {})
+    config, _ = ldai_client.config('initial-config-enabled', context, default_value, {})

     assert config.enabled is True
     assert config.model is None
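
The updated tests all discard the tracker with an underscore. If a direct check on the new return shape were wanted, a sketch along these lines would cover it (it reuses the ldai_client fixture and the 'model-config' flag from this file; the test name is hypothetical):

def test_config_returns_tracker(ldai_client: LDAIClient):
    context = Context.create('user-key')
    default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])

    # The tracker is now the second element of the returned tuple.
    config, tracker = ldai_client.config('model-config', context, default_value, {})

    assert config is not None
    assert tracker is not None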
