diff --git a/ldai/client.py b/ldai/client.py index 90218e2..341c784 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -16,17 +16,69 @@ class LDMessage: content: str -@dataclass -class AIConfigData: - model: Optional[dict] - prompt: Optional[List[LDMessage]] +class ModelConfig: + """ + Configuration related to the model. + """ + + def __init__(self, id: str, temperature: Optional[float] = None, + max_tokens: Optional[int] = None, attributes: dict = {}): + """ + :param id: The ID of the model. + :param temperature: Turning parameter for randomness versus determinism. Exact effect will be determined by the model. + :param max_tokens: The maximum number of tokens. + :param attributes: Additional model-specific attributes. + """ + self._id = id + self._temperature = temperature + self._max_tokens = max_tokens + self._attributes = attributes + + @property + def id(self) -> str: + """ + The ID of the model. + """ + return self._id + + @property + def temperature(self) -> Optional[float]: + """" + Turning parameter for randomness versus determinism. Exact effect will be determined by the model. + """ + return self._temperature + + @property + def max_tokens(self) -> Optional[int]: + """ + The maximum number of tokens. + """ + + return self._max_tokens + + def get_attribute(self, key: str) -> Any: + """ + Retrieve model-specific attributes. + + Accessing a named, typed attribute (e.g. id) will result in the call + being delegated to the appropriate property. + """ + if key == 'id': + return self.id + if key == 'temperature': + return self.temperature + if key == 'maxTokens': + return self.max_tokens + + return self._attributes.get(key) class AIConfig: - def __init__(self, config: AIConfigData, tracker: LDAIConfigTracker, enabled: bool): - self.config = config + def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], prompt: Optional[List[LDMessage]]): self.tracker = tracker self.enabled = enabled + self.model = model + self.prompt = prompt class LDAIClient: @@ -71,9 +123,17 @@ def model_config( for entry in variation['prompt'] ] + model = None + if 'model' in variation: + model = ModelConfig( + id=variation['model']['modelId'], + temperature=variation['model'].get('temperature'), + max_tokens=variation['model'].get('maxTokens'), + attributes=variation['model'], + ) + enabled = variation.get('_ldMeta', {}).get('enabled', False) return AIConfig( - config=AIConfigData(model=variation['model'], prompt=prompt), tracker=LDAIConfigTracker( self.client, variation.get('_ldMeta', {}).get('versionKey', ''), @@ -81,6 +141,8 @@ def model_config( context, ), enabled=bool(enabled), + model=model, + prompt=prompt ) def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str: diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 59d846d..b47331f 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -1,9 +1,8 @@ import pytest from ldclient import Config, Context, LDClient from ldclient.integrations.test_data import TestData -from ldclient.testing.builders import * -from ldai.client import AIConfig, AIConfigData, LDAIClient, LDMessage +from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig from ldai.tracker import LDAIConfigTracker @@ -14,7 +13,7 @@ def td() -> TestData: td.flag('model-config') .variations( { - 'model': {'modelId': 'fakeModel'}, + 'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096}, 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, }, @@ -27,7 +26,7 @@ def td() -> TestData: td.flag('multiple-prompt') .variations( { - 'model': {'modelId': 'fakeModel'}, + 'model': {'modelId': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192}, 'prompt': [ {'role': 'system', 'content': 'Hello, {{name}}!'}, {'role': 'user', 'content': 'The day is, {{day}}!'}, @@ -43,7 +42,7 @@ def td() -> TestData: td.flag('ctx-interpolation') .variations( { - 'model': {'modelId': 'fakeModel'}, + 'model': {'modelId': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'}, 'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, } @@ -55,7 +54,7 @@ def td() -> TestData: td.flag('off-config') .variations( { - 'model': {'modelId': 'fakeModel'}, + 'model': {'modelId': 'fakeModel', 'temperature': 0.1}, 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': False, 'versionKey': 'abcd'}, } @@ -82,81 +81,110 @@ def ldai_client(client: LDClient) -> LDAIClient: return LDAIClient(client) +def test_model_config_delegates_to_properties(): + model = ModelConfig('fakeModel', temperature=0.5, max_tokens=4096, attributes={'extra-attribute': 'value'}) + assert model.id == 'fakeModel' + assert model.temperature == 0.5 + assert model.max_tokens == 4096 + assert model.get_attribute('extra-attribute') == 'value' + assert model.get_attribute('non-existent') is None + + assert model.id == model.get_attribute('id') + assert model.temperature == model.get_attribute('temperature') + assert model.max_tokens == model.get_attribute('maxTokens') + assert model.max_tokens != model.get_attribute('max_tokens') + + def test_model_config_interpolation(ldai_client: LDAIClient, tracker): context = Context.create('user-key') default_value = AIConfig( - config=AIConfigData( - model={'modelId': 'fakeModel'}, - prompt=[LDMessage(role='system', content='Hello, {{name}}!')], - ), tracker=tracker, enabled=True, + model=ModelConfig('fakeModel'), + prompt=[LDMessage(role='system', content='Hello, {{name}}!')], ) variables = {'name': 'World'} config = ldai_client.model_config('model-config', context, default_value, variables) - assert config.config.prompt is not None - assert len(config.config.prompt) > 0 - assert config.config.prompt[0].content == 'Hello, World!' + assert config.prompt is not None + assert len(config.prompt) > 0 + assert config.prompt[0].content == 'Hello, World!' assert config.enabled is True + assert config.model is not None + assert config.model.id == 'fakeModel' + assert config.model.temperature == 0.5 + assert config.model.max_tokens == 4096 + def test_model_config_no_variables(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig( - config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True - ) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) config = ldai_client.model_config('model-config', context, default_value, {}) - assert config.config.prompt is not None - assert len(config.config.prompt) > 0 - assert config.config.prompt[0].content == 'Hello, !' + assert config.prompt is not None + assert len(config.prompt) > 0 + assert config.prompt[0].content == 'Hello, !' assert config.enabled is True + assert config.model is not None + assert config.model.id == 'fakeModel' + assert config.model.temperature == 0.5 + assert config.model.max_tokens == 4096 + def test_context_interpolation(ldai_client: LDAIClient, tracker): context = Context.builder('user-key').name("Sandy").build() - default_value = AIConfig( - config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True - ) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) variables = {'name': 'World'} config = ldai_client.model_config( 'ctx-interpolation', context, default_value, variables ) - assert config.config.prompt is not None - assert len(config.config.prompt) > 0 - assert config.config.prompt[0].content == 'Hello, Sandy!' + assert config.prompt is not None + assert len(config.prompt) > 0 + assert config.prompt[0].content == 'Hello, Sandy!' assert config.enabled is True + assert config.model is not None + assert config.model.id == 'fakeModel' + assert config.model.temperature is None + assert config.model.max_tokens is None + assert config.model.get_attribute('extra-attribute') == 'I can be anything I set my mind/type to' + def test_model_config_multiple(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig( - config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True - ) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) variables = {'name': 'World', 'day': 'Monday'} config = ldai_client.model_config( 'multiple-prompt', context, default_value, variables ) - assert config.config.prompt is not None - assert len(config.config.prompt) > 0 - assert config.config.prompt[0].content == 'Hello, World!' - assert config.config.prompt[1].content == 'The day is, Monday!' + assert config.prompt is not None + assert len(config.prompt) > 0 + assert config.prompt[0].content == 'Hello, World!' + assert config.prompt[1].content == 'The day is, Monday!' assert config.enabled is True + assert config.model is not None + assert config.model.id == 'fakeModel' + assert config.model.temperature == 0.7 + assert config.model.max_tokens == 8192 + def test_model_config_disabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig( - config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=False - ) + default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[]) config = ldai_client.model_config('off-config', context, default_value, {}) + assert config.model is not None assert config.enabled is False + assert config.model.id == 'fakeModel' + assert config.model.temperature == 0.1 + assert config.model.max_tokens is None