
Commit 43d2b54

Refactor AI configuration tests to use AICompletionConfigDefault and update async tracking for OpenAI metrics
1 parent: f3c792d

File tree (3 files changed, +35 -28 lines):

ldai/providers/ai_provider_factory.py
ldai/testing/test_model_config.py
ldai/testing/test_tracker.py


ldai/providers/ai_provider_factory.py

Lines changed: 2 additions & 0 deletions
@@ -113,6 +113,8 @@ async def _try_create_provider(
         )
     return None

+    # TODO: REL-10773 OpenAI provider
+    # TODO: REL-10776 Vercel provider
     # For future external providers, use dynamic import
     provider_mappings = {
         # 'openai': ('launchdarkly_server_sdk_ai_openai', 'OpenAIProvider'),
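
The commented-out 'openai' entry above shows the intended shape of the dynamic-import path for future external providers. As a rough sketch of how such a mapping could be resolved at runtime (the module and class names are taken from the commented entry; the helper name and fallback behavior are assumptions, not part of this commit):

import importlib
from typing import Any, Optional

# Hypothetical mapping, mirroring the commented-out entry in the diff above.
PROVIDER_MAPPINGS = {
    'openai': ('launchdarkly_server_sdk_ai_openai', 'OpenAIProvider'),
}


def load_external_provider(name: str) -> Optional[Any]:
    """Import an external provider class only if its package is installed."""
    entry = PROVIDER_MAPPINGS.get(name)
    if entry is None:
        return None
    module_name, class_name = entry
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Package not installed; fall back to None, as the factory does for unknown providers.
        return None
    return getattr(module, class_name)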

ldai/testing/test_model_config.py

Lines changed: 23 additions & 23 deletions
@@ -2,7 +2,7 @@
 from ldclient import Config, Context, LDClient
 from ldclient.integrations.test_data import TestData

-from ldai import AIConfig, LDAIClient, LDMessage, ModelConfig
+from ldai import AICompletionConfigDefault, LDAIClient, LDMessage, ModelConfig


 @pytest.fixture
@@ -133,14 +133,14 @@ def test_model_config_handles_custom():

 def test_uses_default_on_invalid_flag(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(
+    default_value = AICompletionConfigDefault(
         enabled=True,
         model=ModelConfig('fakeModel', parameters={'temperature': 0.5, 'maxTokens': 4096}),
         messages=[LDMessage(role='system', content='Hello, {{name}}!')],
     )
     variables = {'name': 'World'}

-    config, _ = ldai_client.config('missing-flag', context, default_value, variables)
+    config = ldai_client.config('missing-flag', context, default_value, variables)

     assert config.messages is not None
     assert len(config.messages) > 0
@@ -155,14 +155,14 @@ def test_uses_default_on_invalid_flag(ldai_client: LDAIClient):

 def test_model_config_interpolation(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(
+    default_value = AICompletionConfigDefault(
         enabled=True,
         model=ModelConfig('fakeModel'),
         messages=[LDMessage(role='system', content='Hello, {{name}}!')],
     )
     variables = {'name': 'World'}

-    config, _ = ldai_client.config('model-config', context, default_value, variables)
+    config = ldai_client.config('model-config', context, default_value, variables)

     assert config.messages is not None
     assert len(config.messages) > 0
@@ -177,9 +177,9 @@ def test_model_config_interpolation(ldai_client: LDAIClient):

 def test_model_config_no_variables(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
+    default_value = AICompletionConfigDefault(enabled=True, model=ModelConfig('fake-model'), messages=[])

-    config, _ = ldai_client.config('model-config', context, default_value, {})
+    config = ldai_client.config('model-config', context, default_value, {})

     assert config.messages is not None
     assert len(config.messages) > 0
@@ -194,21 +194,21 @@ def test_model_config_no_variables(ldai_client: LDAIClient):

 def test_provider_config_handling(ldai_client: LDAIClient):
     context = Context.builder('user-key').name("Sandy").build()
-    default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
+    default_value = AICompletionConfigDefault(enabled=True, model=ModelConfig('fake-model'), messages=[])
     variables = {'name': 'World'}

-    config, _ = ldai_client.config('model-config', context, default_value, variables)
+    config = ldai_client.config('model-config', context, default_value, variables)

     assert config.provider is not None
     assert config.provider.name == 'fakeProvider'


 def test_context_interpolation(ldai_client: LDAIClient):
     context = Context.builder('user-key').name("Sandy").set('last', 'Beaches').build()
-    default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
+    default_value = AICompletionConfigDefault(enabled=True, model=ModelConfig('fake-model'), messages=[])
     variables = {'name': 'World'}

-    config, _ = ldai_client.config(
+    config = ldai_client.config(
         'ctx-interpolation', context, default_value, variables
     )

@@ -228,10 +228,10 @@ def test_multi_context_interpolation(ldai_client: LDAIClient):
     user_context = Context.builder('user-key').name("Sandy").build()
     org_context = Context.builder('org-key').kind('org').name("LaunchDarkly").set('shortname', 'LD').build()
     context = Context.multi_builder().add(user_context).add(org_context).build()
-    default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
+    default_value = AICompletionConfigDefault(enabled=True, model=ModelConfig('fake-model'), messages=[])
     variables = {'name': 'World'}

-    config, _ = ldai_client.config(
+    config = ldai_client.config(
         'multi-ctx-interpolation', context, default_value, variables
     )

@@ -249,10 +249,10 @@ def test_multi_context_interpolation(ldai_client: LDAIClient):

 def test_model_config_multiple(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(enabled=True, model=ModelConfig('fake-model'), messages=[])
+    default_value = AICompletionConfigDefault(enabled=True, model=ModelConfig('fake-model'), messages=[])
     variables = {'name': 'World', 'day': 'Monday'}

-    config, _ = ldai_client.config(
+    config = ldai_client.config(
         'multiple-messages', context, default_value, variables
     )

@@ -270,9 +270,9 @@ def test_model_config_multiple(ldai_client: LDAIClient):

 def test_model_config_disabled(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
+    default_value = AICompletionConfigDefault(enabled=False, model=ModelConfig('fake-model'), messages=[])

-    config, _ = ldai_client.config('off-config', context, default_value, {})
+    config = ldai_client.config('off-config', context, default_value, {})

     assert config.model is not None
     assert config.enabled is False
@@ -283,9 +283,9 @@ def test_model_config_disabled(ldai_client: LDAIClient):

 def test_model_initial_config_disabled(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
+    default_value = AICompletionConfigDefault(enabled=False, model=ModelConfig('fake-model'), messages=[])

-    config, _ = ldai_client.config('initial-config-disabled', context, default_value, {})
+    config = ldai_client.config('initial-config-disabled', context, default_value, {})

     assert config.enabled is False
     assert config.model is None
@@ -295,9 +295,9 @@ def test_model_initial_config_disabled(ldai_client: LDAIClient):

 def test_model_initial_config_enabled(ldai_client: LDAIClient):
     context = Context.create('user-key')
-    default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
+    default_value = AICompletionConfigDefault(enabled=False, model=ModelConfig('fake-model'), messages=[])

-    config, _ = ldai_client.config('initial-config-enabled', context, default_value, {})
+    config = ldai_client.config('initial-config-enabled', context, default_value, {})

     assert config.enabled is True
     assert config.model is None
@@ -318,9 +318,9 @@ def test_config_method_tracking(ldai_client: LDAIClient):

     client = LDAIClient(mock_client)
     context = Context.create('user-key')
-    default_value = AIConfig(enabled=False, model=ModelConfig('fake-model'), messages=[])
+    default_value = AICompletionConfigDefault(enabled=False, model=ModelConfig('fake-model'), messages=[])

-    config, tracker = client.config('test-config-key', context, default_value)
+    config = client.config('test-config-key', context, default_value)

     mock_client.track.assert_called_once_with(
         '$ld:ai:config:function:single',
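
Two changes run through every test above: the default value is built with AICompletionConfigDefault instead of AIConfig, and LDAIClient.config(...) is now consumed as a single return value rather than unpacked as a (config, tracker) tuple. A minimal sketch of the new call shape, reusing the flag key and placeholder values from these tests (how the tracker is reached after this change is not shown in this diff):

from ldclient import Context
from ldai import AICompletionConfigDefault, LDMessage, ModelConfig

# `ldai_client` is assumed to be an already-initialized LDAIClient,
# e.g. the pytest fixture used by the tests above.
context = Context.create('user-key')
default_value = AICompletionConfigDefault(
    enabled=True,
    model=ModelConfig('fake-model'),
    messages=[LDMessage(role='system', content='Hello, {{name}}!')],
)

# Old shape (pre-commit): config, tracker = ldai_client.config(...)
# New shape: a single return value.
config = ldai_client.config('model-config', context, default_value, {'name': 'World'})
assert config.messages is not None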

ldai/testing/test_tracker.py

Lines changed: 10 additions & 5 deletions
@@ -276,7 +276,8 @@ def test_tracks_bedrock_metrics_with_error(client: LDClient):
     assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


-def test_tracks_openai_metrics(client: LDClient):
+@pytest.mark.asyncio
+async def test_tracks_openai_metrics(client: LDClient):
     context = Context.create("user-key")
     tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)

@@ -292,7 +293,10 @@ def to_dict(self):
                 "completion_tokens": 110,
             }

-    tracker.track_openai_metrics(lambda: Result())
+    async def get_result():
+        return Result()
+
+    await tracker.track_openai_metrics(get_result)

     calls = [
         call(
@@ -326,15 +330,16 @@ def to_dict(self):
     assert tracker.get_summary().usage == TokenUsage(330, 220, 110)


-def test_tracks_openai_metrics_with_exception(client: LDClient):
+@pytest.mark.asyncio
+async def test_tracks_openai_metrics_with_exception(client: LDClient):
     context = Context.create("user-key")
     tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)

-    def raise_exception():
+    async def raise_exception():
         raise ValueError("Something went wrong")

     try:
-        tracker.track_openai_metrics(raise_exception)
+        await tracker.track_openai_metrics(raise_exception)
         assert False, "Should have thrown an exception"
     except ValueError:
         pass
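
With these changes, track_openai_metrics is awaited and receives an async zero-argument callable rather than a plain lambda. A sketch of how calling code might wrap an async OpenAI request in this shape (the AsyncOpenAI client, model name, and messages are placeholders; only the awaited-callable pattern comes from the tests above):

from openai import AsyncOpenAI  # placeholder client, not part of this repository


async def generate_with_tracking(tracker):
    """Pass an async callable; the tracker awaits it and records token usage."""
    openai_client = AsyncOpenAI()

    async def call_openai():
        return await openai_client.chat.completions.create(
            model='gpt-4o-mini',  # placeholder model name
            messages=[{'role': 'user', 'content': 'Hello, World!'}],
        )

    # Mirrors the updated tests: the callable is awaited inside track_openai_metrics,
    # and any exception it raises propagates to the caller. Returning the tracked
    # result to the caller is assumed here, as in the pre-commit synchronous version.
    return await tracker.track_openai_metrics(call_openai)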
