From 642c6d617bc9d6ddfe443b17ce6e7a5eccbfa1bd Mon Sep 17 00:00:00 2001
From: Matthew Keeler
Date: Fri, 8 Nov 2024 10:11:37 -0500
Subject: [PATCH] style: Fix and enforce formatting

---
 Makefile                          |   2 +
 ldai/client.py                    |  49 +++++++++---
 ldai/testing/test_model_config.py | 121 ++++++++++++++++++++++--------
 ldai/tracker.py                   |  82 +++++++++++++++-----
 pyproject.toml                    |   2 +
 5 files changed, 194 insertions(+), 62 deletions(-)

diff --git a/Makefile b/Makefile
index 2ed7ee8..791925c 100644
--- a/Makefile
+++ b/Makefile
@@ -30,6 +30,8 @@ test: install
 lint: #! Run type analysis and linting checks
 lint: install
 	@poetry run mypy ldai
+	@poetry run isort --check --atomic ldai
+	@poetry run pycodestyle ldai
 
 #
 # Documentation generation
diff --git a/ldai/client.py b/ldai/client.py
index 0dcc6e8..0312096 100644
--- a/ldai/client.py
+++ b/ldai/client.py
@@ -1,33 +1,45 @@
+from dataclasses import dataclass
 from typing import Any, Dict, List, Literal, Optional
+
+import chevron
 from ldclient import Context
 from ldclient.client import LDClient
-import chevron
 
 from ldai.tracker import LDAIConfigTracker
-from dataclasses import dataclass
+
 
 @dataclass
-class LDMessage():
+class LDMessage:
     role: Literal['system', 'user', 'assistant']
     content: str
 
+
 @dataclass
-class AIConfigData():
+class AIConfigData:
     model: Optional[dict]
     prompt: Optional[List[LDMessage]]
 
-class AIConfig():
+
+
+class AIConfig:
     def __init__(self, config: AIConfigData, tracker: LDAIConfigTracker, enabled: bool):
         self.config = config
         self.tracker = tracker
         self.enabled = enabled
 
+
 class LDAIClient:
     """The LaunchDarkly AI SDK client object."""
 
     def __init__(self, client: LDClient):
         self.client = client
 
-    def model_config(self, key: str, context: Context, default_value: AIConfig, variables: Optional[Dict[str, Any]] = None) -> AIConfig:
+    def model_config(
+        self,
+        key: str,
+        context: Context,
+        default_value: AIConfig,
+        variables: Optional[Dict[str, Any]] = None,
+    ) -> AIConfig:
         """
         Get the value of a model configuration asynchronously.
 
@@ -43,18 +55,31 @@ def model_config(self, key: str, context: Context, default_value: AIConfig, vari
         if variables:
             all_variables.update(variables)
         all_variables['ldctx'] = context
-
-        if isinstance(variation['prompt'], list) and all(isinstance(entry, dict) for entry in variation['prompt']):
+
+        if isinstance(variation['prompt'], list) and all(
+            isinstance(entry, dict) for entry in variation['prompt']
+        ):
             variation['prompt'] = [
                 LDMessage(
                     role=entry['role'],
-                    content=self.__interpolate_template(entry['content'], all_variables)
+                    content=self.__interpolate_template(
+                        entry['content'], all_variables
+                    ),
                 )
                 for entry in variation['prompt']
             ]
 
-        enabled = variation.get('_ldMeta',{}).get('enabled', False)
-        return AIConfig(config=AIConfigData(model=variation['model'], prompt=variation['prompt']), tracker=LDAIConfigTracker(self.client, variation.get('_ldMeta', {}).get('versionKey', ''), key, context), enabled=bool(enabled))
+        enabled = variation.get('_ldMeta', {}).get('enabled', False)
+        return AIConfig(
+            config=AIConfigData(model=variation['model'], prompt=variation['prompt']),
+            tracker=LDAIConfigTracker(
+                self.client,
+                variation.get('_ldMeta', {}).get('versionKey', ''),
+                key,
+                context,
+            ),
+            enabled=bool(enabled),
+        )
 
     def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:
         """
@@ -64,4 +89,4 @@ def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> st
         :variables: The variables to interpolate into the template.
         :return: The interpolated string.
""" - return chevron.render(template, variables) \ No newline at end of file + return chevron.render(template, variables) diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 826f78e..59d846d 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -1,56 +1,97 @@ import pytest -from ldclient import LDClient, Context, Config +from ldclient import Config, Context, LDClient from ldclient.integrations.test_data import TestData +from ldclient.testing.builders import * + from ldai.client import AIConfig, AIConfigData, LDAIClient, LDMessage from ldai.tracker import LDAIConfigTracker -from ldclient.testing.builders import * @pytest.fixture def td() -> TestData: td = TestData.data_source() - td.update(td.flag('model-config').variations({ - 'model': { 'modelId': 'fakeModel'}, - 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], - '_ldMeta': {'enabled': True, 'versionKey': 'abcd'} - }, "green").variation_for_all(0)) - - td.update(td.flag('multiple-prompt').variations({ - 'model': { 'modelId': 'fakeModel'}, - 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}, {'role': 'user', 'content': 'The day is, {{day}}!'}], - '_ldMeta': {'enabled': True, 'versionKey': 'abcd'} - }, "green").variation_for_all(0)) - - td.update(td.flag('ctx-interpolation').variations({ - 'model': { 'modelId': 'fakeModel'}, - 'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}], - '_ldMeta': {'enabled': True, 'versionKey': 'abcd'} - }).variation_for_all(0)) - - td.update(td.flag('off-config').variations({ - 'model': { 'modelId': 'fakeModel'}, - 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], - '_ldMeta': {'enabled': False, 'versionKey': 'abcd'} - }).variation_for_all(0)) + td.update( + td.flag('model-config') + .variations( + { + 'model': {'modelId': 'fakeModel'}, + 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], + '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, + }, + "green", + ) + .variation_for_all(0) + ) + + td.update( + td.flag('multiple-prompt') + .variations( + { + 'model': {'modelId': 'fakeModel'}, + 'prompt': [ + {'role': 'system', 'content': 'Hello, {{name}}!'}, + {'role': 'user', 'content': 'The day is, {{day}}!'}, + ], + '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, + }, + "green", + ) + .variation_for_all(0) + ) + + td.update( + td.flag('ctx-interpolation') + .variations( + { + 'model': {'modelId': 'fakeModel'}, + 'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}], + '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, + } + ) + .variation_for_all(0) + ) + + td.update( + td.flag('off-config') + .variations( + { + 'model': {'modelId': 'fakeModel'}, + 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], + '_ldMeta': {'enabled': False, 'versionKey': 'abcd'}, + } + ) + .variation_for_all(0) + ) return td + @pytest.fixture def client(td: TestData) -> LDClient: config = Config('sdk-key', update_processor_class=td, send_events=False) return LDClient(config=config) + @pytest.fixture def tracker(client: LDClient) -> LDAIConfigTracker: return LDAIConfigTracker(client, 'abcd', 'model-config', Context.create('user-key')) + @pytest.fixture def ldai_client(client: LDClient) -> LDAIClient: return LDAIClient(client) + def test_model_config_interpolation(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(config=AIConfigData(model={ 'modelId': 'fakeModel'}, prompt=[LDMessage(role='system', content='Hello, {{name}}!')]), 
-                             tracker=tracker, enabled=True)
+    default_value = AIConfig(
+        config=AIConfigData(
+            model={'modelId': 'fakeModel'},
+            prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
+        ),
+        tracker=tracker,
+        enabled=True,
+    )
     variables = {'name': 'World'}
 
     config = ldai_client.model_config('model-config', context, default_value, variables)
 
@@ -60,9 +101,12 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
     assert config.config.prompt[0].content == 'Hello, World!'
     assert config.enabled is True
 
+
 def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
     context = Context.create('user-key')
-    default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True)
+    default_value = AIConfig(
+        config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
+    )
 
     config = ldai_client.model_config('model-config', context, default_value, {})
 
@@ -71,24 +115,34 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
     assert config.config.prompt[0].content == 'Hello, !'
     assert config.enabled is True
 
+
 def test_context_interpolation(ldai_client: LDAIClient, tracker):
     context = Context.builder('user-key').name("Sandy").build()
-    default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True)
+    default_value = AIConfig(
+        config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
+    )
     variables = {'name': 'World'}
 
-    config = ldai_client.model_config('ctx-interpolation', context, default_value, variables)
+    config = ldai_client.model_config(
+        'ctx-interpolation', context, default_value, variables
+    )
 
     assert config.config.prompt is not None
     assert len(config.config.prompt) > 0
     assert config.config.prompt[0].content == 'Hello, Sandy!'
     assert config.enabled is True
 
+
 def test_model_config_multiple(ldai_client: LDAIClient, tracker):
     context = Context.create('user-key')
-    default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True)
+    default_value = AIConfig(
+        config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
+    )
     variables = {'name': 'World', 'day': 'Monday'}
 
-    config = ldai_client.model_config('multiple-prompt', context, default_value, variables)
+    config = ldai_client.model_config(
+        'multiple-prompt', context, default_value, variables
+    )
 
     assert config.config.prompt is not None
     assert len(config.config.prompt) > 0
@@ -96,9 +150,12 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker):
     assert config.config.prompt[1].content == 'The day is, Monday!'
     assert config.enabled is True
 
+
 def test_model_config_disabled(ldai_client: LDAIClient, tracker):
     context = Context.create('user-key')
-    default_value = AIConfig(config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=False)
+    default_value = AIConfig(
+        config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=False
+    )
 
     config = ldai_client.model_config('off-config', context, default_value, {})
diff --git a/ldai/tracker.py b/ldai/tracker.py
index 1c3bbc7..93fec35 100644
--- a/ldai/tracker.py
+++ b/ldai/tracker.py
@@ -1,11 +1,13 @@
-from enum import Enum
 import time
+from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, Union
+
 from ldclient import Context, LDClient
-from dataclasses import dataclass
+
 
 @dataclass
-class TokenMetrics():
+class TokenMetrics:
     """
     Metrics for token usage in AI operations.
 
     :param total: Total number of tokens used.
     :param input: Number of input tokens.
     :param output: Number of output tokens.
     """
+
     total: int
     input: int
-    output: int # type: ignore
+    output: int  # type: ignore
+
 
 @dataclass
 class FeedbackKind(Enum):
     """
     Types of feedback that can be provided for AI operations.
     """
+
     Positive = "positive"
     Negative = "negative"
 
+
 @dataclass
-class TokenUsage():
+class TokenUsage:
     """
     Tracks token usage for AI operations.
 
@@ -34,6 +40,7 @@
     :param prompt_tokens: Number of tokens in the prompt.
     :param completion_tokens: Number of tokens in the completion.
     """
+
     total_tokens: int
     prompt_tokens: int
     completion_tokens: int
 
@@ -50,8 +57,9 @@ def to_metrics(self):
             'output': self['completion_tokens'],
         }
 
-@dataclass
-class LDOpenAIUsage():
+
+@dataclass
+class LDOpenAIUsage:
     """
     LaunchDarkly-specific OpenAI usage tracking.
 
@@ -59,15 +67,18 @@ class LDOpenAIUsage():
     :param prompt_tokens: Number of tokens in the prompt.
     :param completion_tokens: Number of tokens in the completion.
     """
+
     total_tokens: int
     prompt_tokens: int
     completion_tokens: int
 
+
 @dataclass
 class OpenAITokenUsage:
     """
     Tracks OpenAI-specific token usage.
     """
+
     def __init__(self, data: LDOpenAIUsage):
         """
         Initialize OpenAI token usage tracking.
 
@@ -89,12 +100,14 @@ def to_metrics(self) -> TokenMetrics:
             input=self.prompt_tokens,
             output=self.completion_tokens,
         )
-    
+
+
 @dataclass
 class BedrockTokenUsage:
     """
     Tracks AWS Bedrock-specific token usage.
     """
+
     def __init__(self, data: dict):
         """
         Initialize Bedrock token usage tracking.
 
@@ -116,12 +129,16 @@ def to_metrics(self) -> TokenMetrics:
             input=self.inputTokens,
             output=self.outputTokens,
         )
-    
+
+
 class LDAIConfigTracker:
     """
     Tracks configuration and usage metrics for LaunchDarkly AI operations.
     """
-    def __init__(self, ld_client: LDClient, version_key: str, config_key: str, context: Context):
+
+    def __init__(
+        self, ld_client: LDClient, version_key: str, config_key: str, context: Context
+    ):
         """
         Initialize an AI configuration tracker.
 
@@ -145,14 +162,16 @@ def __get_track_data(self):
             'versionKey': self.version_key,
             'configKey': self.config_key,
         }
-    
+
     def track_duration(self, duration: int) -> None:
         """
         Manually track the duration of an AI operation.
 
         :param duration: Duration in milliseconds.
         """
-        self.ld_client.track('$ld:ai:duration:total', self.context, self.__get_track_data(), duration)
+        self.ld_client.track(
+            '$ld:ai:duration:total', self.context, self.__get_track_data(), duration
+        )
 
     def track_duration_of(self, func):
         """
@@ -175,15 +194,27 @@ def track_feedback(self, feedback: Dict[str, FeedbackKind]) -> None:
         :param feedback: Dictionary containing feedback kind.
         """
         if feedback['kind'] == FeedbackKind.Positive:
-            self.ld_client.track('$ld:ai:feedback:user:positive', self.context, self.__get_track_data(), 1)
+            self.ld_client.track(
+                '$ld:ai:feedback:user:positive',
+                self.context,
+                self.__get_track_data(),
+                1,
+            )
         elif feedback['kind'] == FeedbackKind.Negative:
-            self.ld_client.track('$ld:ai:feedback:user:negative', self.context, self.__get_track_data(), 1)
+            self.ld_client.track(
+                '$ld:ai:feedback:user:negative',
+                self.context,
+                self.__get_track_data(),
+                1,
+            )
 
     def track_success(self) -> None:
         """
         Track a successful AI generation.
""" - self.ld_client.track('$ld:ai:generation', self.context, self.__get_track_data(), 1) + self.ld_client.track( + '$ld:ai:generation', self.context, self.__get_track_data(), 1 + ) def track_openai(self, func): """ @@ -224,8 +255,23 @@ def track_tokens(self, tokens: Union[TokenUsage, BedrockTokenUsage]) -> None: """ token_metrics = tokens.to_metrics() if token_metrics.total > 0: - self.ld_client.track('$ld:ai:tokens:total', self.context, self.__get_track_data(), token_metrics.total) + self.ld_client.track( + '$ld:ai:tokens:total', + self.context, + self.__get_track_data(), + token_metrics.total, + ) if token_metrics.input > 0: - self.ld_client.track('$ld:ai:tokens:input', self.context, self.__get_track_data(), token_metrics.input) + self.ld_client.track( + '$ld:ai:tokens:input', + self.context, + self.__get_track_data(), + token_metrics.input, + ) if token_metrics.output > 0: - self.ld_client.track('$ld:ai:tokens:output', self.context, self.__get_track_data(), token_metrics.output) + self.ld_client.track( + '$ld:ai:tokens:output', + self.context, + self.__get_track_data(), + token_metrics.output, + ) diff --git a/pyproject.toml b/pyproject.toml index 5829046..c03027d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,8 @@ exclude = [ python = ">=3.8" launchdarkly-server-sdk = ">=9.4.0" chevron = "=0.14.0" +pycodestyle = "^2.12.1" +isort = "^5.13.2" [tool.poetry.group.dev.dependencies]