Skip to content

Commit d62a779

Browse files
authored
feat: Update AI tracker to include model & provider name for metrics generation (#58)
**Requirements** - [X] I have added test coverage for new or changed functionality - [ ] I have followed the repository's [pull request submission guidelines](../blob/main/CONTRIBUTING.md#submitting-pull-requests) - [ ] I have validated my changes against all supported platform versions **Related issues** Provide links to any issues in this repository or elsewhere relating to this pull request. **Describe the solution you've provided** Provide a clear and concise description of what you expect to happen. **Describe alternatives you've considered** Provide a clear and concise description of any alternative solutions or features you've considered. **Additional context** spec PR here: launchdarkly/sdk-specs#112
2 parents ed02047 + 05c15cf commit d62a779

File tree

3 files changed

+53
-39
lines changed

3 files changed

+53
-39
lines changed

ldai/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,8 @@ def __evaluate(
402402
variation.get('_ldMeta', {}).get('variationKey', ''),
403403
key,
404404
int(variation.get('_ldMeta', {}).get('version', 1)),
405+
model.name if model else '',
406+
provider_config.name if provider_config else '',
405407
context,
406408
)
407409

ldai/testing/test_tracker.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def client(td: TestData) -> LDClient:
4242

4343
def test_summary_starts_empty(client: LDClient):
4444
context = Context.create("user-key")
45-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 1, context)
45+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 1, "fakeModel", "fakeProvider", context)
4646

4747
assert tracker.get_summary().duration is None
4848
assert tracker.get_summary().feedback is None
@@ -52,13 +52,13 @@ def test_summary_starts_empty(client: LDClient):
5252

5353
def test_tracks_duration(client: LDClient):
5454
context = Context.create("user-key")
55-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
55+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
5656
tracker.track_duration(100)
5757

5858
client.track.assert_called_with( # type: ignore
5959
"$ld:ai:duration:total",
6060
context,
61-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
61+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
6262
100,
6363
)
6464

@@ -67,7 +67,7 @@ def test_tracks_duration(client: LDClient):
6767

6868
def test_tracks_duration_of(client: LDClient):
6969
context = Context.create("user-key")
70-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
70+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
7171
tracker.track_duration_of(lambda: sleep(0.01))
7272

7373
calls = client.track.mock_calls # type: ignore
@@ -79,19 +79,21 @@ def test_tracks_duration_of(client: LDClient):
7979
"variationKey": "variation-key",
8080
"configKey": "config-key",
8181
"version": 3,
82+
"modelName": "fakeModel",
83+
"providerName": "fakeProvider",
8284
}
8385
assert calls[0].args[3] == pytest.approx(10, rel=10)
8486

8587

8688
def test_tracks_time_to_first_token(client: LDClient):
8789
context = Context.create("user-key")
88-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
90+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
8991
tracker.track_time_to_first_token(100)
9092

9193
client.track.assert_called_with( # type: ignore
9294
"$ld:ai:tokens:ttf",
9395
context,
94-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
96+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
9597
100,
9698
)
9799

@@ -100,7 +102,7 @@ def test_tracks_time_to_first_token(client: LDClient):
100102

101103
def test_tracks_duration_of_with_exception(client: LDClient):
102104
context = Context.create("user-key")
103-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
105+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
104106

105107
def sleep_and_throw():
106108
sleep(0.01)
@@ -121,13 +123,15 @@ def sleep_and_throw():
121123
"variationKey": "variation-key",
122124
"configKey": "config-key",
123125
"version": 3,
126+
"modelName": "fakeModel",
127+
"providerName": "fakeProvider",
124128
}
125129
assert calls[0].args[3] == pytest.approx(10, rel=10)
126130

127131

128132
def test_tracks_token_usage(client: LDClient):
129133
context = Context.create("user-key")
130-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
134+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
131135

132136
tokens = TokenUsage(300, 200, 100)
133137
tracker.track_tokens(tokens)
@@ -136,19 +140,19 @@ def test_tracks_token_usage(client: LDClient):
136140
call(
137141
"$ld:ai:tokens:total",
138142
context,
139-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
143+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
140144
300,
141145
),
142146
call(
143147
"$ld:ai:tokens:input",
144148
context,
145-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
149+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
146150
200,
147151
),
148152
call(
149153
"$ld:ai:tokens:output",
150154
context,
151-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
155+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
152156
100,
153157
),
154158
]
@@ -160,7 +164,7 @@ def test_tracks_token_usage(client: LDClient):
160164

161165
def test_tracks_bedrock_metrics(client: LDClient):
162166
context = Context.create("user-key")
163-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
167+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
164168

165169
bedrock_result = {
166170
"ResponseMetadata": {"HTTPStatusCode": 200},
@@ -179,31 +183,31 @@ def test_tracks_bedrock_metrics(client: LDClient):
179183
call(
180184
"$ld:ai:generation:success",
181185
context,
182-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
186+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
183187
1,
184188
),
185189
call(
186190
"$ld:ai:duration:total",
187191
context,
188-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
192+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
189193
50,
190194
),
191195
call(
192196
"$ld:ai:tokens:total",
193197
context,
194-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
198+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
195199
330,
196200
),
197201
call(
198202
"$ld:ai:tokens:input",
199203
context,
200-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
204+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
201205
220,
202206
),
203207
call(
204208
"$ld:ai:tokens:output",
205209
context,
206-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
210+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
207211
110,
208212
),
209213
]
@@ -217,7 +221,7 @@ def test_tracks_bedrock_metrics(client: LDClient):
217221

218222
def test_tracks_bedrock_metrics_with_error(client: LDClient):
219223
context = Context.create("user-key")
220-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
224+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
221225

222226
bedrock_result = {
223227
"ResponseMetadata": {"HTTPStatusCode": 500},
@@ -236,31 +240,31 @@ def test_tracks_bedrock_metrics_with_error(client: LDClient):
236240
call(
237241
"$ld:ai:generation:error",
238242
context,
239-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
243+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
240244
1,
241245
),
242246
call(
243247
"$ld:ai:duration:total",
244248
context,
245-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
249+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
246250
50,
247251
),
248252
call(
249253
"$ld:ai:tokens:total",
250254
context,
251-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
255+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
252256
330,
253257
),
254258
call(
255259
"$ld:ai:tokens:input",
256260
context,
257-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
261+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
258262
220,
259263
),
260264
call(
261265
"$ld:ai:tokens:output",
262266
context,
263-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
267+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
264268
110,
265269
),
266270
]
@@ -274,7 +278,7 @@ def test_tracks_bedrock_metrics_with_error(client: LDClient):
274278

275279
def test_tracks_openai_metrics(client: LDClient):
276280
context = Context.create("user-key")
277-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
281+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
278282

279283
class Result:
280284
def __init__(self):
@@ -294,25 +298,25 @@ def to_dict(self):
294298
call(
295299
"$ld:ai:generation:success",
296300
context,
297-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
301+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
298302
1,
299303
),
300304
call(
301305
"$ld:ai:tokens:total",
302306
context,
303-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
307+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
304308
330,
305309
),
306310
call(
307311
"$ld:ai:tokens:input",
308312
context,
309-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
313+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
310314
220,
311315
),
312316
call(
313317
"$ld:ai:tokens:output",
314318
context,
315-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
319+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
316320
110,
317321
),
318322
]
@@ -324,7 +328,7 @@ def to_dict(self):
324328

325329
def test_tracks_openai_metrics_with_exception(client: LDClient):
326330
context = Context.create("user-key")
327-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
331+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
328332

329333
def raise_exception():
330334
raise ValueError("Something went wrong")
@@ -339,7 +343,7 @@ def raise_exception():
339343
call(
340344
"$ld:ai:generation:error",
341345
context,
342-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
346+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
343347
1,
344348
),
345349
]
@@ -358,29 +362,29 @@ def raise_exception():
358362
)
359363
def test_tracks_feedback(client: LDClient, kind: FeedbackKind, label: str):
360364
context = Context.create("user-key")
361-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
365+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
362366

363367
tracker.track_feedback({"kind": kind})
364368

365369
client.track.assert_called_with( # type: ignore
366370
f"$ld:ai:feedback:user:{label}",
367371
context,
368-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
372+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
369373
1,
370374
)
371375
assert tracker.get_summary().feedback == {"kind": kind}
372376

373377

374378
def test_tracks_success(client: LDClient):
375379
context = Context.create("user-key")
376-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
380+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
377381
tracker.track_success()
378382

379383
calls = [
380384
call(
381385
"$ld:ai:generation:success",
382386
context,
383-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
387+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
384388
1,
385389
),
386390
]
@@ -392,14 +396,14 @@ def test_tracks_success(client: LDClient):
392396

393397
def test_tracks_error(client: LDClient):
394398
context = Context.create("user-key")
395-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
399+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
396400
tracker.track_error()
397401

398402
calls = [
399403
call(
400404
"$ld:ai:generation:error",
401405
context,
402-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
406+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
403407
1,
404408
),
405409
]
@@ -411,21 +415,21 @@ def test_tracks_error(client: LDClient):
411415

412416
def test_error_overwrites_success(client: LDClient):
413417
context = Context.create("user-key")
414-
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
418+
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
415419
tracker.track_success()
416420
tracker.track_error()
417421

418422
calls = [
419423
call(
420424
"$ld:ai:generation:success",
421425
context,
422-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
426+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
423427
1,
424428
),
425429
call(
426430
"$ld:ai:generation:error",
427431
context,
428-
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
432+
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
429433
1,
430434
),
431435
]

ldai/tracker.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def __init__(
7474
variation_key: str,
7575
config_key: str,
7676
version: int,
77+
model_name: str,
78+
provider_name: str,
7779
context: Context,
7880
):
7981
"""
@@ -83,12 +85,16 @@ def __init__(
8385
:param variation_key: Variation key for tracking.
8486
:param config_key: Configuration key for tracking.
8587
:param version: Version of the variation.
88+
:param model_name: Name of the model used.
89+
:param provider_name: Name of the provider used.
8690
:param context: Context for evaluation.
8791
"""
8892
self._ld_client = ld_client
8993
self._variation_key = variation_key
9094
self._config_key = config_key
9195
self._version = version
96+
self._model_name = model_name
97+
self._provider_name = provider_name
9298
self._context = context
9399
self._summary = LDAIMetricSummary()
94100

@@ -102,6 +108,8 @@ def __get_track_data(self):
102108
"variationKey": self._variation_key,
103109
"configKey": self._config_key,
104110
"version": self._version,
111+
"modelName": self._model_name,
112+
"providerName": self._provider_name,
105113
}
106114

107115
def track_duration(self, duration: int) -> None:

0 commit comments

Comments
 (0)