From 84a22f1ab3e7e3e8ea49f2cc3c43f8a3bf3139cc Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 22 Nov 2024 12:41:30 -0500 Subject: [PATCH 1/8] fix!: Rename prompt to messages --- ldai/client.py | 16 ++++----- ldai/testing/test_model_config.py | 56 +++++++++++++++---------------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index 1fc79d2..106c61d 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -74,11 +74,11 @@ def get_attribute(self, key: str) -> Any: class AIConfig: - def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], prompt: Optional[List[LDMessage]]): + def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]]): self.tracker = tracker self.enabled = enabled self.model = model - self.prompt = prompt + self.messages = messages class LDAIClient: @@ -110,18 +110,18 @@ def model_config( all_variables.update(variables) all_variables['ldctx'] = context - prompt = None - if 'prompt' in variation and isinstance(variation['prompt'], list) and all( - isinstance(entry, dict) for entry in variation['prompt'] + messages = None + if 'messages' in variation and isinstance(variation['messages'], list) and all( + isinstance(entry, dict) for entry in variation['messages'] ): - prompt = [ + messages = [ LDMessage( role=entry['role'], content=self.__interpolate_template( entry['content'], all_variables ), ) - for entry in variation['prompt'] + for entry in variation['messages'] ] model = None @@ -143,7 +143,7 @@ def model_config( ), enabled=bool(enabled), model=model, - prompt=prompt + messages=messages ) def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str: diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 9bc7199..77fd1c3 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -14,7 +14,7 @@ def td() -> 
TestData: .variations( { 'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096}, - 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], + 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, }, "green", @@ -23,11 +23,11 @@ def td() -> TestData: ) td.update( - td.flag('multiple-prompt') + td.flag('multiple-messages') .variations( { 'model': {'modelId': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192}, - 'prompt': [ + 'messages': [ {'role': 'system', 'content': 'Hello, {{name}}!'}, {'role': 'user', 'content': 'The day is, {{day}}!'}, ], @@ -43,7 +43,7 @@ def td() -> TestData: .variations( { 'model': {'modelId': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'}, - 'prompt': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}], + 'messages': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, } ) @@ -55,7 +55,7 @@ def td() -> TestData: .variations( { 'model': {'modelId': 'fakeModel', 'temperature': 0.1}, - 'prompt': [{'role': 'system', 'content': 'Hello, {{name}}!'}], + 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': False, 'versionKey': 'abcd'}, } ) @@ -127,15 +127,15 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker): tracker=tracker, enabled=True, model=ModelConfig('fakeModel'), - prompt=[LDMessage(role='system', content='Hello, {{name}}!')], + messages=[LDMessage(role='system', content='Hello, {{name}}!')], ) variables = {'name': 'World'} config = ldai_client.model_config('model-config', context, default_value, variables) - assert config.prompt is not None - assert len(config.prompt) > 0 - assert config.prompt[0].content == 'Hello, World!' + assert config.messages is not None + assert len(config.messages) > 0 + assert config.messages[0].content == 'Hello, World!' 
assert config.enabled is True assert config.model is not None @@ -146,13 +146,13 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker): def test_model_config_no_variables(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) config = ldai_client.model_config('model-config', context, default_value, {}) - assert config.prompt is not None - assert len(config.prompt) > 0 - assert config.prompt[0].content == 'Hello, !' + assert config.messages is not None + assert len(config.messages) > 0 + assert config.messages[0].content == 'Hello, !' assert config.enabled is True assert config.model is not None @@ -163,16 +163,16 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker): def test_context_interpolation(ldai_client: LDAIClient, tracker): context = Context.builder('user-key').name("Sandy").build() - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World'} config = ldai_client.model_config( 'ctx-interpolation', context, default_value, variables ) - assert config.prompt is not None - assert len(config.prompt) > 0 - assert config.prompt[0].content == 'Hello, Sandy!' + assert config.messages is not None + assert len(config.messages) > 0 + assert config.messages[0].content == 'Hello, Sandy!' 
assert config.enabled is True assert config.model is not None @@ -184,17 +184,17 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker): def test_model_config_multiple(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World', 'day': 'Monday'} config = ldai_client.model_config( - 'multiple-prompt', context, default_value, variables + 'multiple-messages', context, default_value, variables ) - assert config.prompt is not None - assert len(config.prompt) > 0 - assert config.prompt[0].content == 'Hello, World!' - assert config.prompt[1].content == 'The day is, Monday!' + assert config.messages is not None + assert len(config.messages) > 0 + assert config.messages[0].content == 'Hello, World!' + assert config.messages[1].content == 'The day is, Monday!' 
assert config.enabled is True assert config.model is not None @@ -205,7 +205,7 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker): def test_model_config_disabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) config = ldai_client.model_config('off-config', context, default_value, {}) @@ -218,21 +218,21 @@ def test_model_config_disabled(ldai_client: LDAIClient, tracker): def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) config = ldai_client.model_config('initial-config-disabled', context, default_value, {}) assert config.enabled is False assert config.model is None - assert config.prompt is None + assert config.messages is None def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') - default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), prompt=[]) + default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) config = ldai_client.model_config('initial-config-enabled', context, default_value, {}) assert config.enabled is True assert config.model is None - assert config.prompt is None + assert config.messages is None From 2ce71fa0e57d87699be9d93a847efa1df29f672d Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 22 Nov 2024 12:44:14 -0500 Subject: [PATCH 2/8] fix!: Rename `LDAIClient.model_config` to `LDAIClient.config` --- ldai/client.py | 4 ++-- ldai/testing/test_model_config.py | 14 +++++++------- 2 files changed, 9 
insertions(+), 9 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index 106c61d..e6b8d3b 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -87,7 +87,7 @@ class LDAIClient: def __init__(self, client: LDClient): self.client = client - def model_config( + def config( self, key: str, context: Context, @@ -95,7 +95,7 @@ def model_config( variables: Optional[Dict[str, Any]] = None, ) -> AIConfig: """ - Get the value of a model configuration asynchronously. + Get the value of a model configuration. :param key: The key of the model configuration. :param context: The context to evaluate the model configuration in. diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 77fd1c3..2e8bfc8 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -131,7 +131,7 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker): ) variables = {'name': 'World'} - config = ldai_client.model_config('model-config', context, default_value, variables) + config = ldai_client.config('model-config', context, default_value, variables) assert config.messages is not None assert len(config.messages) > 0 @@ -148,7 +148,7 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker): context = Context.create('user-key') default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.model_config('model-config', context, default_value, {}) + config = ldai_client.config('model-config', context, default_value, {}) assert config.messages is not None assert len(config.messages) > 0 @@ -166,7 +166,7 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker): default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World'} - config = ldai_client.model_config( + config = ldai_client.config( 'ctx-interpolation', context, default_value, variables ) @@ -187,7 +187,7 @@ def 
test_model_config_multiple(ldai_client: LDAIClient, tracker): default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) variables = {'name': 'World', 'day': 'Monday'} - config = ldai_client.model_config( + config = ldai_client.config( 'multiple-messages', context, default_value, variables ) @@ -207,7 +207,7 @@ def test_model_config_disabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.model_config('off-config', context, default_value, {}) + config = ldai_client.config('off-config', context, default_value, {}) assert config.model is not None assert config.enabled is False @@ -220,7 +220,7 @@ def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.model_config('initial-config-disabled', context, default_value, {}) + config = ldai_client.config('initial-config-disabled', context, default_value, {}) assert config.enabled is False assert config.model is None @@ -231,7 +231,7 @@ def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker): context = Context.create('user-key') default_value = AIConfig(tracker=tracker, enabled=False, model=ModelConfig('fake-model'), messages=[]) - config = ldai_client.model_config('initial-config-enabled', context, default_value, {}) + config = ldai_client.config('initial-config-enabled', context, default_value, {}) assert config.enabled is True assert config.model is None From 2f9fc61c6974a796747f301f4f3162c1096d6b2e Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 22 Nov 2024 16:48:51 -0500 Subject: [PATCH 3/8] feat: Add support for provider config --- ldai/client.py | 27 +++++++++++++++++++++++++-- ldai/testing/test_model_config.py | 14 
++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index e6b8d3b..46b637d 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -73,12 +73,29 @@ def get_attribute(self, key: str) -> Any: return self._attributes.get(key) +class ProviderConfig: + """ + Configuration related to the provider. + """ + + def __init__(self, id: str): + self._id = id + + @property + def id(self) -> str: + """ + The ID of the provider. + """ + return self._id + + class AIConfig: - def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]]): + def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]], provider: Optional[ProviderConfig] = None): self.tracker = tracker self.enabled = enabled self.model = model self.messages = messages + self.provider = provider class LDAIClient: @@ -124,6 +141,11 @@ def config( for entry in variation['messages'] ] + provider_config = None + if 'provider' in variation and isinstance(variation['provider'], dict): + provider = variation['provider'] + provider_config = ProviderConfig(provider.get('id', '')) + model = None if 'model' in variation: model = ModelConfig( @@ -143,7 +165,8 @@ def config( ), enabled=bool(enabled), model=model, - messages=messages + messages=messages, + provider=provider_config, ) def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str: diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 2e8bfc8..30981fb 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -14,6 +14,7 @@ def td() -> TestData: .variations( { 'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096}, + 'provider': {'id': 'fakeProvider'}, 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, }, @@ -161,6 
+162,17 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker): assert config.model.max_tokens == 4096 +def test_provider_config_handling(ldai_client: LDAIClient, tracker): + context = Context.builder('user-key').name("Sandy").build() + default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) + variables = {'name': 'World'} + + config = ldai_client.config('model-config', context, default_value, variables) + + assert config.provider is not None + assert config.provider.id == 'fakeProvider' + + def test_context_interpolation(ldai_client: LDAIClient, tracker): context = Context.builder('user-key').name("Sandy").build() default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[]) @@ -225,6 +237,7 @@ def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker): assert config.enabled is False assert config.model is None assert config.messages is None + assert config.provider is None def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker): @@ -236,3 +249,4 @@ def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker): assert config.enabled is True assert config.model is None assert config.messages is None + assert config.provider is None From 3941b1576e9dee080576180a7db3f9522754e396 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 22 Nov 2024 16:49:57 -0500 Subject: [PATCH 4/8] fix!: Suffix track methods with metrics --- ldai/tracker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ldai/tracker.py b/ldai/tracker.py index 93fec35..d179674 100644 --- a/ldai/tracker.py +++ b/ldai/tracker.py @@ -216,7 +216,7 @@ def track_success(self) -> None: '$ld:ai:generation', self.context, self.__get_track_data(), 1 ) - def track_openai(self, func): + def track_openai_metrics(self, func): """ Track OpenAI-specific operations. 
@@ -228,7 +228,7 @@ def track_openai(self, func): self.track_tokens(OpenAITokenUsage(result.usage)) return result - def track_bedrock_converse(self, res: dict) -> dict: + def track_bedrock_converse_metrics(self, res: dict) -> dict: """ Track AWS Bedrock conversation operations. From 5a70f80668c40a71ea9f2a618bf749b8e2121cf9 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 22 Nov 2024 16:58:03 -0500 Subject: [PATCH 5/8] fix!: Remove max_tokens and temperature as top level model config keys --- ldai/client.py | 40 ++++++------------------------- ldai/testing/test_model_config.py | 35 ++++++++++++--------------- 2 files changed, 22 insertions(+), 53 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index 46b637d..0f4741b 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -21,18 +21,13 @@ class ModelConfig: Configuration related to the model. """ - def __init__(self, id: str, temperature: Optional[float] = None, - max_tokens: Optional[int] = None, attributes: dict = {}): + def __init__(self, id: str, parameters: dict = {}): """ :param id: The ID of the model. - :param temperature: Turning parameter for randomness versus determinism. Exact effect will be determined by the model. - :param max_tokens: The maximum number of tokens. - :param attributes: Additional model-specific attributes. + :param parameters: Additional model-specific parameters. """ self._id = id - self._temperature = temperature - self._max_tokens = max_tokens - self._attributes = attributes + self._parameters = parameters @property def id(self) -> str: @@ -41,36 +36,17 @@ def id(self) -> str: """ return self._id - @property - def temperature(self) -> Optional[float]: - """" - Turning parameter for randomness versus determinism. Exact effect will be determined by the model. - """ - return self._temperature - - @property - def max_tokens(self) -> Optional[int]: - """ - The maximum number of tokens. 
- """ - - return self._max_tokens - - def get_attribute(self, key: str) -> Any: + def get_parameter(self, key: str) -> Any: """ - Retrieve model-specific attributes. + Retrieve model-specific parameters. Accessing a named, typed attribute (e.g. id) will result in the call being delegated to the appropriate property. """ if key == 'id': return self.id - if key == 'temperature': - return self.temperature - if key == 'maxTokens': - return self.max_tokens - return self._attributes.get(key) + return self._parameters.get(key) class ProviderConfig: @@ -150,9 +126,7 @@ def config( if 'model' in variation: model = ModelConfig( id=variation['model']['modelId'], - temperature=variation['model'].get('temperature'), - max_tokens=variation['model'].get('maxTokens'), - attributes=variation['model'], + parameters=variation['model'], ) enabled = variation.get('_ldMeta', {}).get('enabled', False) diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 30981fb..ead401b 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -109,17 +109,12 @@ def ldai_client(client: LDClient) -> LDAIClient: def test_model_config_delegates_to_properties(): - model = ModelConfig('fakeModel', temperature=0.5, max_tokens=4096, attributes={'extra-attribute': 'value'}) + model = ModelConfig('fakeModel', parameters={'extra-attribute': 'value'}) assert model.id == 'fakeModel' - assert model.temperature == 0.5 - assert model.max_tokens == 4096 - assert model.get_attribute('extra-attribute') == 'value' - assert model.get_attribute('non-existent') is None + assert model.get_parameter('extra-attribute') == 'value' + assert model.get_parameter('non-existent') is None - assert model.id == model.get_attribute('id') - assert model.temperature == model.get_attribute('temperature') - assert model.max_tokens == model.get_attribute('maxTokens') - assert model.max_tokens != model.get_attribute('max_tokens') + assert model.id == model.get_parameter('id') 
def test_model_config_interpolation(ldai_client: LDAIClient, tracker): @@ -141,8 +136,8 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker): assert config.model is not None assert config.model.id == 'fakeModel' - assert config.model.temperature == 0.5 - assert config.model.max_tokens == 4096 + assert config.model.get_parameter('temperature') == 0.5 + assert config.model.get_parameter('maxTokens') == 4096 def test_model_config_no_variables(ldai_client: LDAIClient, tracker): @@ -158,8 +153,8 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker): assert config.model is not None assert config.model.id == 'fakeModel' - assert config.model.temperature == 0.5 - assert config.model.max_tokens == 4096 + assert config.model.get_parameter('temperature') == 0.5 + assert config.model.get_parameter('maxTokens') == 4096 def test_provider_config_handling(ldai_client: LDAIClient, tracker): @@ -189,9 +184,9 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker): assert config.model is not None assert config.model.id == 'fakeModel' - assert config.model.temperature is None - assert config.model.max_tokens is None - assert config.model.get_attribute('extra-attribute') == 'I can be anything I set my mind/type to' + assert config.model.get_parameter('temperature') is None + assert config.model.get_parameter('maxTokens') is None + assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to' def test_model_config_multiple(ldai_client: LDAIClient, tracker): @@ -211,8 +206,8 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker): assert config.model is not None assert config.model.id == 'fakeModel' - assert config.model.temperature == 0.7 - assert config.model.max_tokens == 8192 + assert config.model.get_parameter('temperature') == 0.7 + assert config.model.get_parameter('maxTokens') == 8192 def test_model_config_disabled(ldai_client: LDAIClient, tracker): @@ -224,8 +219,8 @@ def 
test_model_config_disabled(ldai_client: LDAIClient, tracker): assert config.model is not None assert config.enabled is False assert config.model.id == 'fakeModel' - assert config.model.temperature == 0.1 - assert config.model.max_tokens is None + assert config.model.get_parameter('temperature') == 0.1 + assert config.model.get_parameter('maxTokens') is None def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker): From 19a0535d64040ec9c61957303b52d3b39404f687 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 22 Nov 2024 16:59:34 -0500 Subject: [PATCH 6/8] fix!: Change modelId to id --- ldai/client.py | 2 +- ldai/testing/test_model_config.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index 0f4741b..79377c5 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -125,7 +125,7 @@ def config( model = None if 'model' in variation: model = ModelConfig( - id=variation['model']['modelId'], + id=variation['model']['id'], parameters=variation['model'], ) diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index ead401b..9c65e10 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -13,7 +13,7 @@ def td() -> TestData: td.flag('model-config') .variations( { - 'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096}, + 'model': {'id': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096}, 'provider': {'id': 'fakeProvider'}, 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, @@ -27,7 +27,7 @@ def td() -> TestData: td.flag('multiple-messages') .variations( { - 'model': {'modelId': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192}, + 'model': {'id': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192}, 'messages': [ {'role': 'system', 'content': 'Hello, {{name}}!'}, {'role': 'user', 'content': 'The day is, {{day}}!'}, @@ -43,7 +43,7 @@ def td() -> 
TestData: td.flag('ctx-interpolation') .variations( { - 'model': {'modelId': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'}, + 'model': {'id': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'}, 'messages': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, } @@ -55,7 +55,7 @@ def td() -> TestData: td.flag('off-config') .variations( { - 'model': {'modelId': 'fakeModel', 'temperature': 0.1}, + 'model': {'id': 'fakeModel', 'temperature': 0.1}, 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': False, 'versionKey': 'abcd'}, } From f3f67af84461df53bb1933430fc78edf8a72bf12 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 22 Nov 2024 17:24:58 -0500 Subject: [PATCH 7/8] fix!: Nest parameters under model --- ldai/client.py | 10 +++++++--- ldai/testing/test_model_config.py | 8 ++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index 79377c5..b3708da 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -21,7 +21,7 @@ class ModelConfig: Configuration related to the model. """ - def __init__(self, id: str, parameters: dict = {}): + def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None): """ :param id: The ID of the model. :param parameters: Additional model-specific parameters. 
@@ -46,6 +46,9 @@ def get_parameter(self, key: str) -> Any: if key == 'id': return self.id + if self._parameters is None: + return None + return self._parameters.get(key) @@ -123,10 +126,11 @@ def config( provider_config = ProviderConfig(provider.get('id', '')) model = None - if 'model' in variation: + if 'model' in variation and isinstance(variation['model'], dict): + parameters = variation['model'].get('parameters', None) model = ModelConfig( id=variation['model']['id'], - parameters=variation['model'], + parameters=parameters ) enabled = variation.get('_ldMeta', {}).get('enabled', False) diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 9c65e10..8bf902a 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -13,7 +13,7 @@ def td() -> TestData: td.flag('model-config') .variations( { - 'model': {'id': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096}, + 'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}}, 'provider': {'id': 'fakeProvider'}, 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, @@ -27,7 +27,7 @@ def td() -> TestData: td.flag('multiple-messages') .variations( { - 'model': {'id': 'fakeModel', 'temperature': 0.7, 'maxTokens': 8192}, + 'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.7, 'maxTokens': 8192}}, 'messages': [ {'role': 'system', 'content': 'Hello, {{name}}!'}, {'role': 'user', 'content': 'The day is, {{day}}!'}, @@ -43,7 +43,7 @@ def td() -> TestData: td.flag('ctx-interpolation') .variations( { - 'model': {'id': 'fakeModel', 'extra-attribute': 'I can be anything I set my mind/type to'}, + 'model': {'id': 'fakeModel', 'parameters': {'extra-attribute': 'I can be anything I set my mind/type to'}}, 'messages': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, } @@ -55,7 +55,7 @@ def td() -> TestData: 
td.flag('off-config') .variations( { - 'model': {'id': 'fakeModel', 'temperature': 0.1}, + 'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.1}}, 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': False, 'versionKey': 'abcd'}, } From b433f2fbda0327951810203a8690c94b0ea217fa Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 22 Nov 2024 17:34:05 -0500 Subject: [PATCH 8/8] feat: Add custom parameter support to model config --- ldai/client.py | 17 +++++++++++++++-- ldai/testing/test_model_config.py | 10 +++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/ldai/client.py b/ldai/client.py index b3708da..e2d53a2 100644 --- a/ldai/client.py +++ b/ldai/client.py @@ -21,13 +21,15 @@ class ModelConfig: Configuration related to the model. """ - def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None): + def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None): """ :param id: The ID of the model. :param parameters: Additional model-specific parameters. + :param custom: Additional customer provided data. """ self._id = id self._parameters = parameters + self._custom = custom @property def id(self) -> str: @@ -51,6 +53,15 @@ def get_parameter(self, key: str) -> Any: return self._parameters.get(key) + def get_custom(self, key: str) -> Any: + """ + Retrieve customer provided data. 
+ """ + if self._custom is None: + return None + + return self._custom.get(key) + class ProviderConfig: """ @@ -128,9 +139,11 @@ def config( model = None if 'model' in variation and isinstance(variation['model'], dict): parameters = variation['model'].get('parameters', None) + custom = variation['model'].get('custom', None) model = ModelConfig( id=variation['model']['id'], - parameters=parameters + parameters=parameters, + custom=custom ) enabled = variation.get('_ldMeta', {}).get('enabled', False) diff --git a/ldai/testing/test_model_config.py b/ldai/testing/test_model_config.py index 8bf902a..6f97a4d 100644 --- a/ldai/testing/test_model_config.py +++ b/ldai/testing/test_model_config.py @@ -13,7 +13,7 @@ def td() -> TestData: td.flag('model-config') .variations( { - 'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}}, + 'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}, 'custom': {'extra-attribute': 'value'}}, 'provider': {'id': 'fakeProvider'}, 'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}], '_ldMeta': {'enabled': True, 'versionKey': 'abcd'}, @@ -117,6 +117,14 @@ def test_model_config_delegates_to_properties(): assert model.id == model.get_parameter('id') +def test_model_config_handles_custom(): + model = ModelConfig('fakeModel', custom={'extra-attribute': 'value'}) + assert model.id == 'fakeModel' + assert model.get_parameter('extra-attribute') is None + assert model.get_custom('non-existent') is None + assert model.get_custom('id') is None + + def test_model_config_interpolation(ldai_client: LDAIClient, tracker): context = Context.create('user-key') default_value = AIConfig(