From dcbd88f749002132352d99fe39ce545da64f0f9a Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 15 Nov 2024 11:48:54 -0500 Subject: [PATCH 1/4] Inline generation config into top level config --- ldai/client.py | 14 ++++----- ldai/testing/test_model_config.py | 50 +++++++++++++------------------ 2 files changed, 25 insertions(+), 39 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index 90218e2..2cc03a7 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -16,17 +16,12 @@ class LDMessage: content: str -@dataclass -class AIConfigData: - model: Optional[dict] - prompt: Optional[List[LDMessage]] - - class AIConfig: - def __init__(self, config: AIConfigData, tracker: LDAIConfigTracker, enabled: bool): - self.config = config + def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[dict], prompt: Optional[List[LDMessage]]): self.tracker = tracker self.enabled = enabled + self.model = model + self.prompt = prompt class LDAIClient: @@ -73,7 +68,6 @@ def model_config( enabled = variation.get('_ldMeta', {}).get('enabled', False) return AIConfig( - config=AIConfigData(model=variation['model'], prompt=prompt), tracker=LDAIConfigTracker( self.client, variation.get('_ldMeta', {}).get('versionKey', ''), @@ -81,6 +75,8 @@ def model_config( context, ), enabled=bool(enabled), + model=variation['model'], + prompt=prompt ) def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str: diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 59d846d..df80790 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -3,7 +3,7 @@ from ldclient.integrations.test_data import TestData from ldclient.testing.builders import * -from ldai.client import AIConfig, AIConfigData, LDAIClient, LDMessage +from ldai.client import AIConfig, LDAIClient, LDMessage from ldai.tracker import LDAIConfigTracker @@ -85,77 +85,67 @@ def ldai_client(client: LDClient) -> LDAIClient: def 
test_model_config_interpolation(ldai_client: LDAIClient, tracker): context = Context.create('user-key') default_value = AIConfig( - config=AIConfigData( - model={'modelId': 'fakeModel'}, - prompt=[LDMessage(role='system', content='Hello, {{name}}!')], - ), tracker=tracker, enabled=True, + model={'modelId': 'fakeModel'}, + prompt=[LDMessage(role='system', content='Hello, {{name}}!')], ) variables = {'name': 'World'} config = ldai_client.model_config('model-config', context, default_value, variables) - assert config.config.prompt is not None - assert len(config.config.prompt) > 0 - assert config.config.prompt[0].content == 'Hello, World!' + assert config.prompt is not None + assert len(config.prompt) > 0 + assert config.prompt[0].content == 'Hello, World!' assert config.enabled is True def test_model_config_no_variables(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig( - config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True - ) + default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[]) config = ldai_client.model_config('model-config', context, default_value, {}) - assert config.config.prompt is not None - assert len(config.config.prompt) > 0 - assert config.config.prompt[0].content == 'Hello, !' + assert config.prompt is not None + assert len(config.prompt) > 0 + assert config.prompt[0].content == 'Hello, !' 
assert config.enabled is True def test_context_interpolation(ldai_client: LDAIClient, tracker): context = Context.builder('user-key').name("Sandy").build() - default_value = AIConfig( - config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True - ) + default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[]) variables = {'name': 'World'} config = ldai_client.model_config( 'ctx-interpolation', context, default_value, variables ) - assert config.config.prompt is not None - assert len(config.config.prompt) > 0 - assert config.config.prompt[0].content == 'Hello, Sandy!' + assert config.prompt is not None + assert len(config.prompt) > 0 + assert config.prompt[0].content == 'Hello, Sandy!' assert config.enabled is True def test_model_config_multiple(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig( - config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True - ) + default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[]) variables = {'name': 'World', 'day': 'Monday'} config = ldai_client.model_config( 'multiple-prompt', context, default_value, variables ) - assert config.config.prompt is not None - assert len(config.config.prompt) > 0 - assert config.config.prompt[0].content == 'Hello, World!' - assert config.config.prompt[1].content == 'The day is, Monday!' + assert config.prompt is not None + assert len(config.prompt) > 0 + assert config.prompt[0].content == 'Hello, World!' + assert config.prompt[1].content == 'The day is, Monday!' 
assert config.enabled is True def test_model_config_disabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig( - config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=False - ) + default_value = AIConfig(tracker=tracker, enabled=False, model={}, prompt=[]) config = ldai_client.model_config('off-config', context, default_value, {}) From 6b4f8c1420158c815884d54ab324e7a509b797cd Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 15 Nov 2024 13:22:04 -0500 Subject: [PATCH 2/4] Expand tests and assertions for new model config --- ldai/client.py | 70 ++++++++++++++++++++++++++++++- ldai/testing/test_model_config.py | 55 +++++++++++++++++++----- 2 files changed, 112 insertions(+), 13 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index 2cc03a7..1939b16 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -16,8 +16,65 @@ class LDMessage: content: str +class ModelConfig: + """ + Configuration related to the model. + """ + + def __init__(self, id: str, temperature: Optional[float] = None, + max_tokens: Optional[int] = None, attributes: dict = {}): + """ + :param id: The ID of the model. + :param temperature: Tuning parameter for randomness versus determinism. Exact effect will be determined by the model. + :param max_tokens: The maximum number of tokens. + :param attributes: Additional model-specific attributes. + """ + self._id = id + self._temperature = temperature + self._max_tokens = max_tokens + self._attributes = attributes + + @property + def id(self): + """ + The ID of the model. + """ + return self._id + + @property + def temperature(self): + """ + Tuning parameter for randomness versus determinism. Exact effect will be determined by the model. + """ + return self._temperature + + @property + def max_tokens(self): + """ + The maximum number of tokens. + """ + + return self._max_tokens + + def get_attribute(self, key: str): + """ + Retrieve model-specific attributes.
+ + Accessing a named, typed attribute (e.g. id) will result in the call + being delegated to the appropriate property. + """ + if key == 'id': + return self.id + if key == 'temperature': + return self.temperature + if key == 'maxTokens': + return self.max_tokens + + return self._attributes.get(key) + + class AIConfig: - def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[dict], prompt: Optional[List[LDMessage]]): + def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], prompt: Optional[List[LDMessage]]): self.tracker = tracker self.enabled = enabled self.model = model @@ -66,6 +123,15 @@ def model_config( for entry in variation['prompt'] ] + model = None + if 'model' in variation: + model = ModelConfig( + id=variation['model']['modelId'], + temperature=variation['model'].get('temperature'), + max_tokens=variation['model'].get('maxTokens'), + attributes=variation['model'], + ) + enabled = variation.get('_ldMeta', {}).get('enabled', False) return AIConfig( tracker=LDAIConfigTracker( @@ -75,7 +141,7 @@ def model_config( context, ), enabled=bool(enabled), - model=variation['model'], + model=model, prompt=prompt ) diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index df80790..75068fc 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -1,9 +1,8 @@ import pytest from ldclient import Config, Context, LDClient from ldclient.integrations.test_data import TestData -from ldclient.testing.builders import * -from ldai.client import AIConfig, LDAIClient, LDMessage +from ldai.client import AIConfig, ModelConfig, LDAIClient, LDMessage from ldai.tracker import LDAIConfigTracker @@ -14,7 +13,7 @@ def td() -> TestData: td.flag('model-config') .variations( { - 'model': {'modelId': 'fakeModel'}, + 'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096}, 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': 
{'enabled': True, 'versionKey': 'abcd'}, }, @@ -27,7 +26,7 @@ def td() -> TestData: td.flag('multiple-prompt') .variations( { - 'model': {'modelId': 'fakeModel'}, + 'model': {'modelId': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192}, 'prompt': [ {'role': 'system', 'content': 'Hello, {{name}}!'}, {'role': 'user', 'content': 'The day is, {{day}}!'}, @@ -43,7 +42,7 @@ def td() -> TestData: td.flag('ctx-interpolation') .variations( { - 'model': {'modelId': 'fakeModel'}, + 'model': {'modelId': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'}, 'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, } @@ -55,7 +54,7 @@ def td() -> TestData: td.flag('off-config') .variations( { - 'model': {'modelId': 'fakeModel'}, + 'model': {'modelId': 'fakeModel', 'temperature': 0.1}, 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': False, 'versionKey': 'abcd'}, } @@ -82,12 +81,26 @@ def ldai_client(client: LDClient) -> LDAIClient: return LDAIClient(client) +def test_model_config_delegates_to_properties(): + model = ModelConfig('fakeModel', temperature=0.5, max_tokens=4096, attributes={'extra-attribute': 'value'}) + assert model.id == 'fakeModel' + assert model.temperature == 0.5 + assert model.max_tokens == 4096 + assert model.get_attribute('extra-attribute') == 'value' + assert model.get_attribute('non-existent') is None + + assert model.id == model.get_attribute('id') + assert model.temperature == model.get_attribute('temperature') + assert model.max_tokens == model.get_attribute('maxTokens') + assert model.max_tokens != model.get_attribute('max_tokens') + + def test_model_config_interpolation(ldai_client: LDAIClient, tracker): context = Context.create('user-key') default_value = AIConfig( tracker=tracker, enabled=True, - model={'modelId': 'fakeModel'}, + model=ModelConfig('fakeModel'), prompt=[LDMessage(role='system', content='Hello, {{name}}!')], ) 
variables = {'name': 'World'} @@ -99,10 +112,14 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker): assert config.prompt[0].content == 'Hello, World!' assert config.enabled is True + assert config.model.id == 'fakeModel' + assert config.model.temperature == 0.5 + assert config.model.max_tokens == 4096 + def test_model_config_no_variables(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) config = ldai_client.model_config('model-config', context, default_value, {}) @@ -111,10 +128,14 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker): assert config.prompt[0].content == 'Hello, !' assert config.enabled is True + assert config.model.id == 'fakeModel' + assert config.model.temperature == 0.5 + assert config.model.max_tokens == 4096 + def test_context_interpolation(ldai_client: LDAIClient, tracker): context = Context.builder('user-key').name("Sandy").build() - default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) variables = {'name': 'World'} config = ldai_client.model_config( @@ -126,10 +147,15 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker): assert config.prompt[0].content == 'Hello, Sandy!' 
assert config.enabled is True + assert config.model.id == 'fakeModel' + assert config.model.temperature is None + assert config.model.max_tokens is None + assert config.model.get_attribute('extra-attribute') == 'I can be anything I set my mind/type to' + def test_model_config_multiple(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) variables = {'name': 'World', 'day': 'Monday'} config = ldai_client.model_config( @@ -142,11 +168,18 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker): assert config.prompt[1].content == 'The day is, Monday!' assert config.enabled is True + assert config.model.id == 'fakeModel' + assert config.model.temperature == 0.7 + assert config.model.max_tokens == 8192 + def test_model_config_disabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model={}, prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[]) config = ldai_client.model_config('off-config', context, default_value, {}) assert config.enabled is False + assert config.model.id == 'fakeModel' + assert config.model.temperature == 0.1 + assert config.model.max_tokens is None From d543fc44b0035a14aec94df57b014debfbf972f5 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 15 Nov 2024 13:30:22 -0500 Subject: [PATCH 3/4] Fix linting --- ldai/testing/test_model_config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 75068fc..b47331f 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -2,7 +2,7 @@ from ldclient import Config, Context, LDClient from ldclient.integrations.test_data import 
TestData -from ldai.client import AIConfig, ModelConfig, LDAIClient, LDMessage +from ldai.client import AIConfig, LDAIClient, LDMessage, ModelConfig from ldai.tracker import LDAIConfigTracker @@ -112,6 +112,7 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker): assert config.prompt[0].content == 'Hello, World!' assert config.enabled is True + assert config.model is not None assert config.model.id == 'fakeModel' assert config.model.temperature == 0.5 assert config.model.max_tokens == 4096 @@ -128,6 +129,7 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker): assert config.prompt[0].content == 'Hello, !' assert config.enabled is True + assert config.model is not None assert config.model.id == 'fakeModel' assert config.model.temperature == 0.5 assert config.model.max_tokens == 4096 @@ -147,6 +149,7 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker): assert config.prompt[0].content == 'Hello, Sandy!' assert config.enabled is True + assert config.model is not None assert config.model.id == 'fakeModel' assert config.model.temperature is None assert config.model.max_tokens is None @@ -168,6 +171,7 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker): assert config.prompt[1].content == 'The day is, Monday!' 
assert config.enabled is True + assert config.model is not None assert config.model.id == 'fakeModel' assert config.model.temperature == 0.7 assert config.model.max_tokens == 8192 @@ -179,6 +183,7 @@ def test_model_config_disabled(ldai_client: LDAIClient, tracker): config = ldai_client.model_config('off-config', context, default_value, {}) + assert config.model is not None assert config.enabled is False assert config.model.id == 'fakeModel' assert config.model.temperature == 0.1 From f68ff7f07fbb822bbf2051ccf680ce118809a963 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 15 Nov 2024 16:04:41 -0500 Subject: [PATCH 4/4] Add optional return types --- ldai/client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index 1939b16..341c784 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -35,28 +35,28 @@ def __init__(self, id: str, temperature: Optional[float] = None, self._attributes = attributes @property - def id(self): + def id(self) -> str: """ The ID of the model. """ return self._id @property - def temperature(self): + def temperature(self) -> Optional[float]: """ Tuning parameter for randomness versus determinism. Exact effect will be determined by the model. """ return self._temperature @property - def max_tokens(self): + def max_tokens(self) -> Optional[int]: """ The maximum number of tokens. """ return self._max_tokens - def get_attribute(self, key: str): + def get_attribute(self, key: str) -> Any: """ Retrieve model-specific attributes.