Skip to content

Commit d2a2ea7

Browse files
committed
feat: Add support for provider config
1 parent 3a3e913 commit d2a2ea7

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

ldai/client.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,29 @@ def get_attribute(self, key: str) -> Any:
7373
return self._attributes.get(key)
7474

7575

76+
class ProviderConfig:
77+
"""
78+
Configuration related to the provider.
79+
"""
80+
81+
def __init__(self, id: str):
82+
self._id = id
83+
84+
@property
85+
def id(self) -> str:
86+
"""
87+
The ID of the provider.
88+
"""
89+
return self._id
90+
91+
7692
class AIConfig:
77-
def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]]):
93+
def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[ModelConfig], messages: Optional[List[LDMessage]], provider: Optional[ProviderConfig] = None):
7894
self.tracker = tracker
7995
self.enabled = enabled
8096
self.model = model
8197
self.messages = messages
98+
self.provider = provider
8299

83100

84101
class LDAIClient:
@@ -124,6 +141,11 @@ def config(
124141
for entry in variation['messages']
125142
]
126143

144+
provider_config = None
145+
if 'provider' in variation and isinstance(variation['provider'], dict):
146+
provider = variation['provider']
147+
provider_config = ProviderConfig(provider.get('id', ''))
148+
127149
model = None
128150
if 'model' in variation:
129151
model = ModelConfig(
@@ -143,7 +165,8 @@ def config(
143165
),
144166
enabled=bool(enabled),
145167
model=model,
146-
messages=messages
168+
messages=messages,
169+
provider=provider_config,
147170
)
148171

149172
def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:

ldai/testing/test_model_config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def td() -> TestData:
1414
.variations(
1515
{
1616
'model': {'modelId': 'fakeModel', 'temperature': 0.5, 'maxTokens': 4096},
17+
'provider': {'id': 'fakeProvider'},
1718
'messages': [{'role': 'system', 'content': 'Hello, {{name}}!'}],
1819
'_ldMeta': {'enabled': True, 'versionKey': 'abcd'},
1920
},
@@ -161,6 +162,17 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
161162
assert config.model.max_tokens == 4096
162163

163164

165+
def test_provider_config_handling(ldai_client: LDAIClient, tracker):
166+
context = Context.builder('user-key').name("Sandy").build()
167+
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
168+
variables = {'name': 'World'}
169+
170+
config = ldai_client.config('model-config', context, default_value, variables)
171+
172+
assert config.provider is not None
173+
assert config.provider.id == 'fakeProvider'
174+
175+
164176
def test_context_interpolation(ldai_client: LDAIClient, tracker):
165177
context = Context.builder('user-key').name("Sandy").build()
166178
default_value = AIConfig(tracker=tracker, enabled=True, model=ModelConfig('fake-model'), messages=[])
@@ -225,6 +237,7 @@ def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker):
225237
assert config.enabled is False
226238
assert config.model is None
227239
assert config.messages is None
240+
assert config.provider is None
228241

229242

230243
def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker):
@@ -236,3 +249,4 @@ def test_model_initial_config_enabled(ldai_client: LDAIClient, tracker):
236249
assert config.enabled is True
237250
assert config.model is None
238251
assert config.messages is None
252+
assert config.provider is None

0 commit comments

Comments
 (0)