Commit 55f34fe

fix!: Remove max_tokens and temperature as top level model config keys
1 parent: 319f64d

2 files changed: 22 additions, 53 deletions


ldai/client.py

Lines changed: 7 additions & 33 deletions
@@ -21,18 +21,13 @@ class ModelConfig:
     Configuration related to the model.
     """

-    def __init__(self, id: str, temperature: Optional[float] = None,
-                 max_tokens: Optional[int] = None, attributes: dict = {}):
+    def __init__(self, id: str, parameters: dict = {}):
         """
         :param id: The ID of the model.
-        :param temperature: Turning parameter for randomness versus determinism. Exact effect will be determined by the model.
-        :param max_tokens: The maximum number of tokens.
-        :param attributes: Additional model-specific attributes.
+        :param parameters: Additional model-specific parameters.
         """
         self._id = id
-        self._temperature = temperature
-        self._max_tokens = max_tokens
-        self._attributes = attributes
+        self._parameters = parameters

     @property
     def id(self) -> str:
@@ -41,36 +36,17 @@ def id(self) -> str:
         """
         return self._id

-    @property
-    def temperature(self) -> Optional[float]:
-        """"
-        Turning parameter for randomness versus determinism. Exact effect will be determined by the model.
-        """
-        return self._temperature
-
-    @property
-    def max_tokens(self) -> Optional[int]:
-        """
-        The maximum number of tokens.
-        """
-
-        return self._max_tokens
-
-    def get_attribute(self, key: str) -> Any:
+    def get_parameter(self, key: str) -> Any:
         """
-        Retrieve model-specific attributes.
+        Retrieve model-specific parameters.

         Accessing a named, typed attribute (e.g. id) will result in the call
         being delegated to the appropriate property.
         """
         if key == 'id':
             return self.id
-        if key == 'temperature':
-            return self.temperature
-        if key == 'maxTokens':
-            return self.max_tokens

-        return self._attributes.get(key)
+        return self._parameters.get(key)


 class ProviderConfig:
@@ -150,9 +126,7 @@ def config(
         if 'model' in variation:
             model = ModelConfig(
                 id=variation['model']['modelId'],
-                temperature=variation['model'].get('temperature'),
-                max_tokens=variation['model'].get('maxTokens'),
-                attributes=variation['model'],
+                parameters=variation['model'],
             )

         enabled = variation.get('_ldMeta', {}).get('enabled', False)
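For callers upgrading across this breaking change, a minimal migration sketch based on the new constructor and get_parameter above (the `from ldai.client import ModelConfig` import path and the 'topP' key are assumptions for illustration, not part of this commit):

from ldai.client import ModelConfig

# Model-specific settings now travel in a single `parameters` dict instead of
# dedicated temperature/max_tokens constructor arguments.
model = ModelConfig(
    id='fakeModel',
    parameters={'temperature': 0.5, 'maxTokens': 4096},
)

# Before: model.temperature and model.max_tokens properties.
# After: every value other than `id` is read through get_parameter().
assert model.get_parameter('temperature') == 0.5
assert model.get_parameter('maxTokens') == 4096

# Unset keys fall through to dict.get() and come back as None
# ('topP' is purely illustrative here).
assert model.get_parameter('topP') is None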

ldai/testing/test_model_config.py

Lines changed: 15 additions & 20 deletions
@@ -109,17 +109,12 @@ def ldai_client(client: LDClient) -> LDAIClient:


 def test_model_config_delegates_to_properties():
-    model = ModelConfig('fakeModel', temperature=0.5, max_tokens=4096, attributes={'extra-attribute': 'value'})
+    model = ModelConfig('fakeModel', parameters={'extra-attribute': 'value'})
     assert model.id == 'fakeModel'
-    assert model.temperature == 0.5
-    assert model.max_tokens == 4096
-    assert model.get_attribute('extra-attribute') == 'value'
-    assert model.get_attribute('non-existent') is None
+    assert model.get_parameter('extra-attribute') == 'value'
+    assert model.get_parameter('non-existent') is None

-    assert model.id == model.get_attribute('id')
-    assert model.temperature == model.get_attribute('temperature')
-    assert model.max_tokens == model.get_attribute('maxTokens')
-    assert model.max_tokens != model.get_attribute('max_tokens')
+    assert model.id == model.get_parameter('id')


 def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
@@ -141,8 +136,8 @@ def test_model_config_interpolation(ldai_client: LDAIClient, tracker):

     assert config.model is not None
     assert config.model.id == 'fakeModel'
-    assert config.model.temperature == 0.5
-    assert config.model.max_tokens == 4096
+    assert config.model.get_parameter('temperature') == 0.5
+    assert config.model.get_parameter('maxTokens') == 4096


 def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
@@ -158,8 +153,8 @@ def test_model_config_no_variables(ldai_client: LDAIClient, tracker):

     assert config.model is not None
     assert config.model.id == 'fakeModel'
-    assert config.model.temperature == 0.5
-    assert config.model.max_tokens == 4096
+    assert config.model.get_parameter('temperature') == 0.5
+    assert config.model.get_parameter('maxTokens') == 4096


 def test_provider_config_handling(ldai_client: LDAIClient, tracker):
@@ -189,9 +184,9 @@ def test_context_interpolation(ldai_client: LDAIClient, tracker):

     assert config.model is not None
     assert config.model.id == 'fakeModel'
-    assert config.model.temperature is None
-    assert config.model.max_tokens is None
-    assert config.model.get_attribute('extra-attribute') == 'I can be anything I set my mind/type to'
+    assert config.model.get_parameter('temperature') is None
+    assert config.model.get_parameter('maxTokens') is None
+    assert config.model.get_parameter('extra-attribute') == 'I can be anything I set my mind/type to'


 def test_model_config_multiple(ldai_client: LDAIClient, tracker):
@@ -211,8 +206,8 @@ def test_model_config_multiple(ldai_client: LDAIClient, tracker):

     assert config.model is not None
     assert config.model.id == 'fakeModel'
-    assert config.model.temperature == 0.7
-    assert config.model.max_tokens == 8192
+    assert config.model.get_parameter('temperature') == 0.7
+    assert config.model.get_parameter('maxTokens') == 8192


 def test_model_config_disabled(ldai_client: LDAIClient, tracker):
@@ -224,8 +219,8 @@ def test_model_config_disabled(ldai_client: LDAIClient, tracker):
     assert config.model is not None
     assert config.enabled is False
     assert config.model.id == 'fakeModel'
-    assert config.model.temperature == 0.1
-    assert config.model.max_tokens is None
+    assert config.model.get_parameter('temperature') == 0.1
+    assert config.model.get_parameter('maxTokens') is None


 def test_model_initial_config_disabled(ldai_client: LDAIClient, tracker):
