Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions ldai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,32 +28,32 @@ class ModelConfig:
Configuration related to the model.
"""

def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None):
def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None):
"""
:param id: The ID of the model.
:param name: The name of the model.
:param parameters: Additional model-specific parameters.
:param custom: Additional customer provided data.
"""
self._id = id
self._name = name
self._parameters = parameters
self._custom = custom

@property
def id(self) -> str:
def name(self) -> str:
"""
The ID of the model.
The name of the model.
"""
return self._id
return self._name

def get_parameter(self, key: str) -> Any:
"""
Retrieve model-specific parameters.

Accessing a named, typed attribute (e.g. id) will result in the call
Accessing a named, typed attribute (e.g. name) will result in the call
being delegated to the appropriate property.
"""
if key == 'id':
return self.id
if key == 'name':
return self.name

if self._parameters is None:
return None
Expand All @@ -74,7 +74,7 @@ def to_dict(self) -> dict:
Render the given model config as a dictionary object.
"""
return {
'id': self._id,
'name': self._name,
'parameters': self._parameters,
'custom': self._custom,
}
Expand All @@ -85,22 +85,22 @@ class ProviderConfig:
Configuration related to the provider.
"""

def __init__(self, id: str):
self._id = id
def __init__(self, name: str):
self._name = name

@property
def id(self) -> str:
def name(self) -> str:
"""
The ID of the provider.
The name of the provider.
"""
return self._id
return self._name

def to_dict(self) -> dict:
"""
Render the given provider config as a dictionary object.
"""
return {
'id': self._id,
'name': self._name,
}


Expand Down Expand Up @@ -171,14 +171,14 @@ def config(
provider_config = None
if 'provider' in variation and isinstance(variation['provider'], dict):
provider = variation['provider']
provider_config = ProviderConfig(provider.get('id', ''))
provider_config = ProviderConfig(provider.get('name', ''))

model = None
if 'model' in variation and isinstance(variation['model'], dict):
parameters = variation['model'].get('parameters', None)
custom = variation['model'].get('custom', None)
model = ModelConfig(
id=variation['model']['id'],
name=variation['model']['name'],
parameters=parameters,
custom=custom
)
Expand Down
36 changes: 18 additions & 18 deletions ldai/testing/test_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def td() -> TestData:
td.flag('model-config')
.variations(
{
'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}, 'custom': {'extra-attribute': 'value'}},
'provider': {'id': 'fakeProvider'},
'model': {'name': 'fakeModel', 'parameters': {'temperature': 0.5, 'maxTokens': 4096}, 'custom': {'extra-attribute': 'value'}},
'provider': {'name': 'fakeProvider'},
'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
},
Expand All @@ -26,7 +26,7 @@ def td() -> TestData:
td.flag('multiple-messages')
.variations(
{
'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.7, 'maxTokens': 8192}},
'model': {'name': 'fakeModel', 'parameters': {'temperature': 0.7, 'maxTokens': 8192}},
'messages': [
{'role': 'system', 'content': 'Hello, {{name}}!'},
{'role': 'user', 'content': 'The day is, {{day}}!'},
Expand All @@ -42,7 +42,7 @@ def td() -> TestData:
td.flag('ctx-interpolation')
.variations(
{
'model': {'id': 'fakeModel', 'parameters': {'extra-attribute': 'I can be anything I set my mind/type to'}},
'model': {'name': 'fakeModel', 'parameters': {'extra-attribute': 'I can be anything I set my mind/type to'}},
'messages': [{'role': 'system', 'content': 'Hello, {{ldctx.name}}! Is your last name {{ldctx.last}}?'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
}
Expand All @@ -54,7 +54,7 @@ def td() -> TestData:
td.flag('multi-ctx-interpolation')
.variations(
{
'model': {'id': 'fakeModel', 'parameters': {'extra-attribute': 'I can be anything I set my mind/type to'}},
'model': {'name': 'fakeModel', 'parameters': {'extra-attribute': 'I can be anything I set my mind/type to'}},
'messages': [{'role': 'system', 'content': 'Hello, {{ldctx.user.name}}! Do you work for {{ldctx.org.shortname}}?'}],
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
}
Expand All @@ -66,7 +66,7 @@ def td() -> TestData:
td.flag('off-config')
.variations(
{
'model': {'id': 'fakeModel', 'parameters': {'temperature': 0.1}},
'model': {'name': 'fakeModel', 'parameters': {'temperature': 0.1}},
'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
'_ldMeta': {'enabled': False, 'versionKey': 'abcd'},
}
Expand Down Expand Up @@ -116,19 +116,19 @@ def ldai_client(client: LDClient) -> LDAIClient:

def test_model_config_delegates_to_properties():
model = ModelConfig('fakeModel', parameters={'extra-attribute': 'value'})
assert model.id == 'fakeModel'
assert model.name == 'fakeModel'
assert model.get_parameter('extra-attribute') == 'value'
assert model.get_parameter('non-existent') is None

assert model.id == model.get_parameter('id')
assert model.name == model.get_parameter('name')


def test_model_config_handles_custom():
model = ModelConfig('fakeModel', custom={'extra-attribute': 'value'})
assert model.id == 'fakeModel'
assert model.name == 'fakeModel'
assert model.get_parameter('extra-attribute') is None
assert model.get_custom('non-existent') is None
assert model.get_custom('id') is None
assert model.get_custom('name') is None


def test_uses_default_on_invalid_flag(ldai_client: LDAIClient):
Expand All @@ -148,7 +148,7 @@ def test_uses_default_on_invalid_flag(ldai_client: LDAIClient):
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.name == 'fakeModel'
assert config.model.get_parameter('temperature') == 0.5
assert config.model.get_parameter('maxTokens') == 4096

Expand All @@ -170,7 +170,7 @@ def test_model_config_interpolation(ldai_client: LDAIClient):
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.name == 'fakeModel'
assert config.model.get_parameter('temperature') == 0.5
assert config.model.get_parameter('maxTokens') == 4096

Expand All @@ -187,7 +187,7 @@ def test_model_config_no_variables(ldai_client: LDAIClient):
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.name == 'fakeModel'
assert config.model.get_parameter('temperature') == 0.5
assert config.model.get_parameter('maxTokens') == 4096

Expand All @@ -200,7 +200,7 @@ def test_provider_config_handling(ldai_client: LDAIClient):
config, _ = ldai_client.config('model-config', context, default_value, variables)

assert config.provider is not None
assert config.provider.id == 'fakeProvider'
assert config.provider.name == 'fakeProvider'


def test_context_interpolation(ldai_client: LDAIClient):
Expand All @@ -218,7 +218,7 @@ def test_context_interpolation(ldai_client: LDAIClient):
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.name == 'fakeModel'
assert config.model.get_parameter('temperature') is None
assert config.model.get_parameter('maxTokens') is None
assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to'
Expand All @@ -241,7 +241,7 @@ def test_multi_context_interpolation(ldai_client: LDAIClient):
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.name == 'fakeModel'
assert config.model.get_parameter('temperature') is None
assert config.model.get_parameter('maxTokens') is None
assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to'
Expand All @@ -263,7 +263,7 @@ def test_model_config_multiple(ldai_client: LDAIClient):
assert config.enabled is True

assert config.model is not None
assert config.model.id == 'fakeModel'
assert config.model.name == 'fakeModel'
assert config.model.get_parameter('temperature') == 0.7
assert config.model.get_parameter('maxTokens') == 8192

Expand All @@ -276,7 +276,7 @@ def test_model_config_disabled(ldai_client: LDAIClient):

assert config.model is not None
assert config.enabled is False
assert config.model.id == 'fakeModel'
assert config.model.name == 'fakeModel'
assert config.model.get_parameter('temperature') == 0.1
assert config.model.get_parameter('maxTokens') is None

Expand Down
Loading