diff --git a/ldai/client.py b/ldai/client.py index 8854b4b..261012c 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -402,6 +402,8 @@ def __evaluate( variation.get('_ldMeta', {}).get('variationKey', ''), key, int(variation.get('_ldMeta', {}).get('version', 1)), + model.name if model else '', + provider_config.name if provider_config else '', context, ) diff --git a/ldai/testing/test_tracker.py b/ldai/testing/test_tracker.py index 30a20bd..19c8161 100644 --- a/ldai/testing/test_tracker.py +++ b/ldai/testing/test_tracker.py @@ -42,7 +42,7 @@ def client(td: TestData) -> LDClient: def test_summary_starts_empty(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 1, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 1, "fakeModel", "fakeProvider", context) assert tracker.get_summary().duration is None assert tracker.get_summary().feedback is None @@ -52,13 +52,13 @@ def test_summary_starts_empty(client: LDClient): def test_tracks_duration(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) tracker.track_duration(100) client.track.assert_called_with( # type: ignore "$ld:ai:duration:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 100, ) @@ -67,7 +67,7 @@ def test_tracks_duration(client: LDClient): def test_tracks_duration_of(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) tracker.track_duration_of(lambda: sleep(0.01)) calls = client.track.mock_calls # type: ignore @@ -79,19 +79,21 @@ def test_tracks_duration_of(client: LDClient): "variationKey": "variation-key", "configKey": "config-key", "version": 3, + "modelName": "fakeModel", + "providerName": "fakeProvider", } assert calls[0].args[3] == pytest.approx(10, rel=10) def test_tracks_time_to_first_token(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) tracker.track_time_to_first_token(100) client.track.assert_called_with( # type: ignore "$ld:ai:tokens:ttf", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 100, ) @@ -100,7 +102,7 @@ def test_tracks_time_to_first_token(client: LDClient): def test_tracks_duration_of_with_exception(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) def sleep_and_throw(): sleep(0.01) @@ -121,13 +123,15 @@ def sleep_and_throw(): "variationKey": "variation-key", "configKey": "config-key", "version": 3, + "modelName": "fakeModel", + "providerName": "fakeProvider", } assert calls[0].args[3] == pytest.approx(10, rel=10) def test_tracks_token_usage(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) tokens = TokenUsage(300, 200, 100) tracker.track_tokens(tokens) @@ -136,19 +140,19 @@ def test_tracks_token_usage(client: LDClient): call( "$ld:ai:tokens:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 300, ), call( "$ld:ai:tokens:input", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 200, ), call( "$ld:ai:tokens:output", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 100, ), ] @@ -160,7 +164,7 @@ def test_tracks_token_usage(client: LDClient): def test_tracks_bedrock_metrics(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) bedrock_result = { "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -179,31 +183,31 @@ def test_tracks_bedrock_metrics(client: LDClient): call( "$ld:ai:generation:success", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), call( "$ld:ai:duration:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 50, ), call( "$ld:ai:tokens:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 330, ), call( "$ld:ai:tokens:input", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 220, ), call( "$ld:ai:tokens:output", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 110, ), ] @@ -217,7 +221,7 @@ def test_tracks_bedrock_metrics(client: LDClient): def test_tracks_bedrock_metrics_with_error(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) bedrock_result = { "ResponseMetadata": {"HTTPStatusCode": 500}, @@ -236,31 +240,31 @@ def test_tracks_bedrock_metrics_with_error(client: LDClient): call( "$ld:ai:generation:error", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), call( "$ld:ai:duration:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 50, ), call( "$ld:ai:tokens:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 330, ), call( "$ld:ai:tokens:input", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 220, ), call( "$ld:ai:tokens:output", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 110, ), ] @@ -274,7 +278,7 @@ def test_tracks_bedrock_metrics_with_error(client: LDClient): def test_tracks_openai_metrics(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) class Result: def __init__(self): @@ -294,25 +298,25 @@ def to_dict(self): call( "$ld:ai:generation:success", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), call( "$ld:ai:tokens:total", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 330, ), call( "$ld:ai:tokens:input", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 220, ), call( "$ld:ai:tokens:output", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 110, ), ] @@ -324,7 +328,7 @@ def to_dict(self): def test_tracks_openai_metrics_with_exception(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) def raise_exception(): raise ValueError("Something went wrong") @@ -339,7 +343,7 @@ def raise_exception(): call( "$ld:ai:generation:error", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), ] @@ -358,14 +362,14 @@ def raise_exception(): ) def test_tracks_feedback(client: LDClient, kind: FeedbackKind, label: str): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) tracker.track_feedback({"kind": kind}) client.track.assert_called_with( # type: ignore f"$ld:ai:feedback:user:{label}", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ) assert tracker.get_summary().feedback == {"kind": kind} @@ -373,14 +377,14 @@ def test_tracks_feedback(client: LDClient, kind: FeedbackKind, label: str): def test_tracks_success(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) tracker.track_success() calls = [ call( "$ld:ai:generation:success", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), ] @@ -392,14 +396,14 @@ def test_tracks_success(client: LDClient): def test_tracks_error(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) tracker.track_error() calls = [ call( "$ld:ai:generation:error", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), ] @@ -411,7 +415,7 @@ def test_tracks_error(client: LDClient): def test_error_overwrites_success(client: LDClient): context = Context.create("user-key") - tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, context) + tracker = LDAIConfigTracker(client, "variation-key", "config-key", 3, "fakeModel", "fakeProvider", context) tracker.track_success() tracker.track_error() @@ -419,13 +423,13 @@ def test_error_overwrites_success(client: LDClient): call( "$ld:ai:generation:success", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), call( "$ld:ai:generation:error", context, - {"variationKey": "variation-key", "configKey": "config-key", "version": 3}, + {"variationKey": "variation-key", "configKey": "config-key", "version": 3, "modelName": "fakeModel", "providerName": "fakeProvider"}, 1, ), ] diff --git a/ldai/tracker.py b/ldai/tracker.py index 9099833..a049952 100644 --- a/ldai/tracker.py +++ b/ldai/tracker.py @@ -74,6 +74,8 @@ def __init__( variation_key: str, config_key: str, version: int, + model_name: str, + provider_name: str, context: Context, ): """ @@ -83,12 +85,16 @@ def __init__( :param variation_key: Variation key for tracking. :param config_key: Configuration key for tracking. :param version: Version of the variation. + :param model_name: Name of the model used. + :param provider_name: Name of the provider used. :param context: Context for evaluation. """ self._ld_client = ld_client self._variation_key = variation_key self._config_key = config_key self._version = version + self._model_name = model_name + self._provider_name = provider_name self._context = context self._summary = LDAIMetricSummary() @@ -102,6 +108,8 @@ def __get_track_data(self): "variationKey": self._variation_key, "configKey": self._config_key, "version": self._version, + "modelName": self._model_name, + "providerName": self._provider_name, } def track_duration(self, duration: int) -> None: