Skip to content

feat: Update AI tracker to include model & provider name for metrics generation #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ldai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,8 @@ def __evaluate(
variation.get('_ldMeta', {}).get('variationKey', ''),
key,
int(variation.get('_ldMeta', {}).get('version', 1)),
model.name if model else '',
provider_config.name if provider_config else '',
context,
)

Expand Down
82 changes: 43 additions & 39 deletions ldai/testing/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def client(td: TestData) -> LDClient:

def test_summary_starts_empty(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 1, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 1, "fakeModel", "fakeProvider", context)

assert tracker.get_summary().duration is None
assert tracker.get_summary().feedback is None
Expand All @@ -52,13 +52,13 @@ def test_summary_starts_empty(client: LDClient):

def test_tracks_duration(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
tracker.track_duration(100)

client.track.assert_called_with( # type: ignore
"$ld:ai:duration:total",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
100,
)

Expand All @@ -67,7 +67,7 @@ def test_tracks_duration(client: LDClient):

def test_tracks_duration_of(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
tracker.track_duration_of(lambda: sleep(0.01))

calls = client.track.mock_calls # type: ignore
Expand All @@ -79,19 +79,21 @@ def test_tracks_duration_of(client: LDClient):
"variationKey": "variation-key",
"configKey": "config-key",
"version": 3,
"modelName": "fakeModel",
"providerName": "fakeProvider",
}
assert calls[0].args[3] == pytest.approx(10, rel=10)


def test_tracks_time_to_first_token(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
tracker.track_time_to_first_token(100)

client.track.assert_called_with( # type: ignore
"$ld:ai:tokens:ttf",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
100,
)

Expand All @@ -100,7 +102,7 @@ def test_tracks_time_to_first_token(client: LDClient):

def test_tracks_duration_of_with_exception(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)

def sleep_and_throw():
sleep(0.01)
Expand All @@ -121,13 +123,15 @@ def sleep_and_throw():
"variationKey": "variation-key",
"configKey": "config-key",
"version": 3,
"modelName": "fakeModel",
"providerName": "fakeProvider",
}
assert calls[0].args[3] == pytest.approx(10, rel=10)


def test_tracks_token_usage(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)

tokens = TokenUsage(300, 200, 100)
tracker.track_tokens(tokens)
Expand All @@ -136,19 +140,19 @@ def test_tracks_token_usage(client: LDClient):
call(
"$ld:ai:tokens:total",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
300,
),
call(
"$ld:ai:tokens:input",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
200,
),
call(
"$ld:ai:tokens:output",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
100,
),
]
Expand All @@ -160,7 +164,7 @@ def test_tracks_token_usage(client: LDClient):

def test_tracks_bedrock_metrics(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)

bedrock_result = {
"ResponseMetadata": {"HTTPStatusCode": 200},
Expand All @@ -179,31 +183,31 @@ def test_tracks_bedrock_metrics(client: LDClient):
call(
"$ld:ai:generation:success",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
1,
),
call(
"$ld:ai:duration:total",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
50,
),
call(
"$ld:ai:tokens:total",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
330,
),
call(
"$ld:ai:tokens:input",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
220,
),
call(
"$ld:ai:tokens:output",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
110,
),
]
Expand All @@ -217,7 +221,7 @@ def test_tracks_bedrock_metrics(client: LDClient):

def test_tracks_bedrock_metrics_with_error(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)

bedrock_result = {
"ResponseMetadata": {"HTTPStatusCode": 500},
Expand All @@ -236,31 +240,31 @@ def test_tracks_bedrock_metrics_with_error(client: LDClient):
call(
"$ld:ai:generation:error",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
1,
),
call(
"$ld:ai:duration:total",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
50,
),
call(
"$ld:ai:tokens:total",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
330,
),
call(
"$ld:ai:tokens:input",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
220,
),
call(
"$ld:ai:tokens:output",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
110,
),
]
Expand All @@ -274,7 +278,7 @@ def test_tracks_bedrock_metrics_with_error(client: LDClient):

def test_tracks_openai_metrics(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)

class Result:
def __init__(self):
Expand All @@ -294,25 +298,25 @@ def to_dict(self):
call(
"$ld:ai:generation:success",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
1,
),
call(
"$ld:ai:tokens:total",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
330,
),
call(
"$ld:ai:tokens:input",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
220,
),
call(
"$ld:ai:tokens:output",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
110,
),
]
Expand All @@ -324,7 +328,7 @@ def to_dict(self):

def test_tracks_openai_metrics_with_exception(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)

def raise_exception():
raise ValueError("Something went wrong")
Expand All @@ -339,7 +343,7 @@ def raise_exception():
call(
"$ld:ai:generation:error",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
1,
),
]
Expand All @@ -358,29 +362,29 @@ def raise_exception():
)
def test_tracks_feedback(client: LDClient, kind: FeedbackKind, label: str):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)

tracker.track_feedback({"kind": kind})

client.track.assert_called_with( # type: ignore
f"$ld:ai:feedback:user:{label}",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
1,
)
assert tracker.get_summary().feedback == {"kind": kind}


def test_tracks_success(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
tracker.track_success()

calls = [
call(
"$ld:ai:generation:success",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
1,
),
]
Expand All @@ -392,14 +396,14 @@ def test_tracks_success(client: LDClient):

def test_tracks_error(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
tracker.track_error()

calls = [
call(
"$ld:ai:generation:error",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
1,
),
]
Expand All @@ -411,21 +415,21 @@ def test_tracks_error(client: LDClient):

def test_error_overwrites_success(client: LDClient):
context = Context.create("user-key")
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context)
tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context)
tracker.track_success()
tracker.track_error()

calls = [
call(
"$ld:ai:generation:success",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
1,
),
call(
"$ld:ai:generation:error",
context,
{"variationKey": "variation-key", "configKey": "config-key", "version": 3},
{"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"},
1,
),
]
Expand Down
8 changes: 8 additions & 0 deletions ldai/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def __init__(
variation_key: str,
config_key: str,
version: int,
model_name: str,
provider_name: str,
context: Context,
):
"""
Expand All @@ -83,12 +85,16 @@ def __init__(
:param variation_key: Variation key for tracking.
:param config_key: Configuration key for tracking.
:param version: Version of the variation.
:param model_name: Name of the model used.
:param provider_name: Name of the provider used.
:param context: Context for evaluation.
"""
self._ld_client = ld_client
self._variation_key = variation_key
self._config_key = config_key
self._version = version
self._model_name = model_name
self._provider_name = provider_name
self._context = context
self._summary = LDAIMetricSummary()

Expand All @@ -102,6 +108,8 @@ def __get_track_data(self):
"variationKey": self._variation_key,
"configKey": self._config_key,
"version": self._version,
"modelName": self._model_name,
"providerName": self._provider_name,
}

def track_duration(self, duration: int) -> None:
Expand Down