diff --git a/ldai/client.py b/ldai/client.py index 1fc79d2..e2d53a2 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -21,18 +21,15 @@ class ModelConfig: Configuration related to the model. """ - def __init__(self, id: str, temperature: Optional[float] = None, - max_tokens: Optional[int] = None, attributes: dict = {}): + def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None): """ :param id: The ID of the model. - :param temperature: Turning parameter for randomness versus determinism. Exact effect will be determined by the model. - :param max_tokens: The maximum number of tokens. - :param attributes: Additional model-specific attributes. + :param parameters: Additional model-specific parameters. + :param custom: Additional customer provided data. """ self._id = id - self._temperature = temperature - self._max_tokens = max_tokens - self._attributes = attributes + self._parameters = parameters + self._custom = custom @property def id(self) -> str: @@ -41,44 +38,54 @@ def id(self) -> str: """ return self._id - @property - def temperature(self) -> Optional[float]: - """" - Turning parameter for randomness versus determinism. Exact effect will be determined by the model. + def get_parameter(self, key: str) -> Any: """ - return self._temperature - - @property - def max_tokens(self) -> Optional[int]: - """ - The maximum number of tokens. - """ - - return self._max_tokens - - def get_attribute(self, key: str) -> Any: - """ - Retrieve model-specific attributes. + Retrieve model-specific parameters. Accessing a named, typed attribute (e.g. id) will result in the call being delegated to the appropriate property. """ if key == 'id': return self.id - if key == 'temperature': - return self.temperature - if key == 'maxTokens': - return self.max_tokens - return self._attributes.get(key) + if self._parameters is None: + return None + + return self._parameters.get(key) + + def get_custom(self, key: str) -> Any: + """ + Retrieve customer provided data. + """ + if self._custom is None: + return None + + return self._custom.get(key) + + +class ProviderConfig: + """ + Configuration related to the provider. + """ + + def __init__(self, id: str): + self._id = id + + @property + def id(self) -> str: + """ + The ID of the provider. + """ + return self._id class AIConfig: - def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], prompt: Optional[List[LDMessage]]): + def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]], provider: Optional[ProviderConfig] = None): self.tracker = tracker self.enabled = enabled self.model = model - self.prompt = prompt + self.messages = messages + self.provider = provider class LDAIClient: @@ -87,7 +94,7 @@ class LDAIClient: def __init__(self, client: LDClient): self.client = client - def model_config( + def config( self, key: str, context: Context, @@ -95,7 +102,7 @@ def model_config( variables: Optional[Dict[str, Any]] = None, ) -> AIConfig: """ - Get the value of a model configuration asynchronously. + Get the value of a model configuration. :param key: The key of the model configuration. :param context: The context to evaluate the model configuration in. @@ -110,27 +117,33 @@ def model_config( all_variables.update(variables) all_variables['ldctx'] = context - prompt = None - if 'prompt' in variation and isinstance(variation['prompt'], list) and all( - isinstance(entry, dict) for entry in variation['prompt'] + messages = None + if 'messages' in variation and isinstance(variation['messages'], list) and all( + isinstance(entry, dict) for entry in variation['messages'] ): - prompt = [ + messages = [ LDMessage( role=entry['role'], content=self.__interpolate_template( entry['content'], all_variables ), ) - for entry in variation['prompt'] + for entry in variation['messages'] ] + provider_config = None + if 'provider' in variation and isinstance(variation['provider'], dict): + provider = variation['provider'] + provider_config = ProviderConfig(provider.get('id', '')) + model = None - if 'model' in variation: + if 'model' in variation and isinstance(variation['model'], dict): + parameters = variation['model'].get('parameters', None) + custom = variation['model'].get('custom', None) model = ModelConfig( - id=variation['model']['modelId'], - temperature=variation['model'].get('temperature'), - max_tokens=variation['model'].get('maxTokens'), - attributes=variation['model'], + id=variation['model']['id'], + parameters=parameters, + custom=custom ) enabled = variation.get('_ldMeta', {}).get('enabled', False) @@ -143,7 +156,8 @@ def model_config( ), enabled=bool(enabled), model=model, - prompt=prompt + messages=messages, + provider=provider_config, ) def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str: diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 9bc7199..6f97a4d 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -13,8 +13,9 @@ def td() -> TestData: td.flag('model-config') .variations( { - 'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096}, - 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], + 'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}, 'custom': {'extra-attribute': 'value'}}, + 'provider': {'id': 'fakeProvider'}, + 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, }, "green", @@ -23,11 +24,11 @@ def td() -> TestData: ) td.update( - td.flag('multiple-prompt') + td.flag('multiple-messages') .variations( { - 'model': {'modelId': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192}, - 'prompt': [ + 'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.7, 'maxTokens': 8192}}, + 'messages': [ {'role': 'system', 'content': 'Hello, {{name}}!'}, {'role': 'user', 'content': 'The day is, {{day}}!'}, ], @@ -42,8 +43,8 @@ def td() -> TestData: td.flag('ctx-interpolation') .variations( { - 'model': {'modelId': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'}, - 'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}], + 'model': {'id': 'fakeModel', 'parameters': {'extra-attribute': 'I can be anything I set my mind/type to'}}, + 'messages': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, } ) @@ -54,8 +55,8 @@ def td() -> TestData: td.flag('off-config') .variations( { - 'model': {'modelId': 'fakeModel', 'temperature': 0.1}, - 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], + 'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.1}}, + 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': False, 'versionKey': 'abcd'}, } ) @@ -108,17 +109,20 @@ def ldai_client(client: LDClient) -> LDAIClient: def test_model_config_delegates_to_properties(): - model = ModelConfig('fakeModel', temperature=0.5, max_tokens=4096, attributes={'extra-attribute': 'value'}) + model = ModelConfig('fakeModel', parameters={'extra-attribute': 'value'}) assert model.id == 'fakeModel' - assert model.temperature == 0.5 - assert model.max_tokens == 4096 - assert model.get_attribute('extra-attribute') == 'value' - assert model.get_attribute('non-existent') is None + assert model.get_parameter('extra-attribute') == 'value' + assert model.get_parameter('non-existent') is None - assert model.id == model.get_attribute('id') - assert model.temperature == model.get_attribute('temperature') - assert model.max_tokens == model.get_attribute('maxTokens') - assert model.max_tokens != model.get_attribute('max_tokens') + assert model.id == model.get_parameter('id') + + +def test_model_config_handles_custom(): + model = ModelConfig('fakeModel', custom={'extra-attribute': 'value'}) + assert model.id == 'fakeModel' + assert model.get_parameter('extra-attribute') is None + assert model.get_custom('non-existent') is None + assert model.get_custom('id') is None def test_model_config_interpolation(ldai_client: LDAIClient, tracker): @@ -127,112 +131,125 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker): tracker=tracker, enabled=True, model=ModelConfig('fakeModel'), - prompt=[LDMessage(role='system', content='Hello, {{name}}!')], + messages=[LDMessage(role='system', content='Hello, {{name}}!')], ) variables = {'name': 'World'} - config = ldai_client.model_config('model-config', context, default_value, variables) + config = ldai_client.config('model-config', context, default_value, variables) - assert config.prompt is not None - assert len(config.prompt) > 0 - assert config.prompt[0].content == 'Hello, World!' + assert config.messages is not None + assert len(config.messages) > 0 + assert config.messages[0].content == 'Hello, World!' assert config.enabled is True assert config.model is not None assert config.model.id == 'fakeModel' - assert config.model.temperature == 0.5 - assert config.model.max_tokens == 4096 + assert config.model.get_parameter('temperature') == 0.5 + assert config.model.get_parameter('maxTokens') == 4096 def test_model_config_no_variables(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.model_config('model-config', context, default_value, {}) + config = ldai_client.config('model-config', context, default_value, {}) - assert config.prompt is not None - assert len(config.prompt) > 0 - assert config.prompt[0].content == 'Hello, !' + assert config.messages is not None + assert len(config.messages) > 0 + assert config.messages[0].content == 'Hello, !' assert config.enabled is True assert config.model is not None assert config.model.id == 'fakeModel' - assert config.model.temperature == 0.5 - assert config.model.max_tokens == 4096 + assert config.model.get_parameter('temperature') == 0.5 + assert config.model.get_parameter('maxTokens') == 4096 + + +def test_provider_config_handling(ldai_client: LDAIClient, tracker): + context = Context.builder('user-key').name("Sandy").build() + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) + variables = {'name': 'World'} + + config = ldai_client.config('model-config', context, default_value, variables) + + assert config.provider is not None + assert config.provider.id == 'fakeProvider' def test_context_interpolation(ldai_client: LDAIClient, tracker): context = Context.builder('user-key').name("Sandy").build() - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World'} - config = ldai_client.model_config( + config = ldai_client.config( 'ctx-interpolation', context, default_value, variables ) - assert config.prompt is not None - assert len(config.prompt) > 0 - assert config.prompt[0].content == 'Hello, Sandy!' + assert config.messages is not None + assert len(config.messages) > 0 + assert config.messages[0].content == 'Hello, Sandy!' assert config.enabled is True assert config.model is not None assert config.model.id == 'fakeModel' - assert config.model.temperature is None - assert config.model.max_tokens is None - assert config.model.get_attribute('extra-attribute') == 'I can be anything I set my mind/type to' + assert config.model.get_parameter('temperature') is None + assert config.model.get_parameter('maxTokens') is None + assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to' def test_model_config_multiple(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World', 'day': 'Monday'} - config = ldai_client.model_config( - 'multiple-prompt', context, default_value, variables + config = ldai_client.config( + 'multiple-messages', context, default_value, variables ) - assert config.prompt is not None - assert len(config.prompt) > 0 - assert config.prompt[0].content == 'Hello, World!' - assert config.prompt[1].content == 'The day is, Monday!' + assert config.messages is not None + assert len(config.messages) > 0 + assert config.messages[0].content == 'Hello, World!' + assert config.messages[1].content == 'The day is, Monday!' assert config.enabled is True assert config.model is not None assert config.model.id == 'fakeModel' - assert config.model.temperature == 0.7 - assert config.model.max_tokens == 8192 + assert config.model.get_parameter('temperature') == 0.7 + assert config.model.get_parameter('maxTokens') == 8192 def test_model_config_disabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.model_config('off-config', context, default_value, {}) + config = ldai_client.config('off-config', context, default_value, {}) assert config.model is not None assert config.enabled is False assert config.model.id == 'fakeModel' - assert config.model.temperature == 0.1 - assert config.model.max_tokens is None + assert config.model.get_parameter('temperature') == 0.1 + assert config.model.get_parameter('maxTokens') is None def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.model_config('initial-config-disabled', context, default_value, {}) + config = ldai_client.config('initial-config-disabled', context, default_value, {}) assert config.enabled is False assert config.model is None - assert config.prompt is None + assert config.messages is None + assert config.provider is None def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.model_config('initial-config-enabled', context, default_value, {}) + config = ldai_client.config('initial-config-enabled', context, default_value, {}) assert config.enabled is True assert config.model is None - assert config.prompt is None + assert config.messages is None + assert config.provider is None diff --git a/ldai/tracker.py b/ldai/tracker.py index 93fec35..d179674 100644 --- a/ldai/tracker.py +++ b/ldai/tracker.py @@ -216,7 +216,7 @@ def track_success(self) -> None: '$ld:ai:generation', self.context, self.__get_track_data(), 1 ) - def track_openai(self, func): + def track_openai_metrics(self, func): """ Track OpenAI-specific operations. @@ -228,7 +228,7 @@ def track_openai(self, func): self.track_tokens(OpenAITokenUsage(result.usage)) return result - def track_bedrock_converse(self, res: dict) -> dict: + def track_bedrock_converse_metrics(self, res: dict) -> dict: """ Track AWS Bedrock conversation operations.