Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 59 additions & 45 deletions ldai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,15 @@ class ModelConfig:
Configuration related to the model.
"""

def __init__(self, id: str, temperature: Optional[float] = None,
max_tokens: Optional[int] = None, attributes: dict = {}):
def __init__(self, id: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None):
"""
:param id: The ID of the model.
:param temperature: Turning parameter for randomness versus determinism. Exact effect will be determined by the model.
:param max_tokens: The maximum number of tokens.
:param attributes: Additional model-specific attributes.
:param parameters: Additional model-specific parameters.
:param custom: Additional customer provided data.
"""
self._id = id
self._temperature = temperature
self._max_tokens = max_tokens
self._attributes = attributes
self._parameters = parameters
self._custom = custom

@property
def id(self) -> str:
Expand All @@ -41,44 +38,54 @@ def id(self) -> str:
"""
return self._id

@property
def temperature(self) -> Optional[float]:
""""
Turning parameter for randomness versus determinism. Exact effect will be determined by the model.
def get_parameter(self, key: str) -> Any:
"""
return self._temperature

@property
def max_tokens(self) -> Optional[int]:
"""
The maximum number of tokens.
"""

return self._max_tokens

def get_attribute(self, key: str) -> Any:
"""
Retrieve model-specific attributes.
Retrieve model-specific parameters.

Accessing a named, typed attribute (e.g. id) will result in the call
being delegated to the appropriate property.
"""
if key == 'id':
return self.id
if key == 'temperature':
return self.temperature
if key == 'maxTokens':
return self.max_tokens

return self._attributes.get(key)
if self._parameters is None:
return None

return self._parameters.get(key)

def get_custom(self, key: str) -> Any:
"""
Retrieve customer provided data.
"""
if self._custom is None:
return None

return self._custom.get(key)


class ProviderConfig:
"""
Configuration related to the provider.
"""

def __init__(self, id: str):
self._id = id

@property
def id(self) -> str:
"""
The ID of the provider.
"""
return self._id


class AIConfig:
def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], prompt: Optional[List[LDMessage]]):
def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]], provider: Optional[ProviderConfig] = None):
self.tracker = tracker
self.enabled = enabled
self.model = model
self.prompt = prompt
self.messages = messages
self.provider = provider


class LDAIClient:
Expand All @@ -87,15 +94,15 @@ class LDAIClient:
def __init__(self, client: LDClient):
self.client = client

def model_config(
def config(
self,
key: str,
context: Context,
default_value: AIConfig,
variables: Optional[Dict[str, Any]] = None,
) -> AIConfig:
"""
Get the value of a model configuration asynchronously.
Get the value of a model configuration.

:param key: The key of the model configuration.
:param context: The context to evaluate the model configuration in.
Expand All @@ -110,27 +117,33 @@ def model_config(
all_variables.update(variables)
all_variables['ldctx'] = context

prompt = None
if 'prompt' in variation and isinstance(variation['prompt'], list) and all(
isinstance(entry, dict) for entry in variation['prompt']
messages = None
if 'messages' in variation and isinstance(variation['messages'], list) and all(
isinstance(entry, dict) for entry in variation['messages']
):
prompt = [
messages = [
LDMessage(
role=entry['role'],
content=self.__interpolate_template(
entry['content'], all_variables
),
)
for entry in variation['prompt']
for entry in variation['messages']
]

provider_config = None
if 'provider' in variation and isinstance(variation['provider'], dict):
provider = variation['provider']
provider_config = ProviderConfig(provider.get('id', ''))

model = None
if 'model' in variation:
if 'model' in variation and isinstance(variation['model'], dict):
parameters = variation['model'].get('parameters', None)
custom = variation['model'].get('custom', None)
model = ModelConfig(
id=variation['model']['modelId'],
temperature=variation['model'].get('temperature'),
max_tokens=variation['model'].get('maxTokens'),
attributes=variation['model'],
id=variation['model']['id'],
parameters=parameters,
custom=custom
)

enabled = variation.get('_ldMeta', {}).get('enabled', False)
Expand All @@ -143,7 +156,8 @@ def model_config(
),
enabled=bool(enabled),
model=model,
prompt=prompt
messages=messages,
provider=provider_config,
)

def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:
Expand Down
Loading