Skip to content

Commit dcbd88f

Browse files
committed
Inline generation config into top level config
1 parent 7a0eb0c commit dcbd88f

File tree

2 files changed

+25
-39
lines changed

2 files changed

+25
-39
lines changed

ldai/client.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,12 @@ class LDMessage:
1616
content: str
1717

1818

19-
@dataclass
20-
class AIConfigData:
21-
model: Optional[dict]
22-
prompt: Optional[List[LDMessage]]
23-
24-
2519
class AIConfig:
26-
def __init__(self, config: AIConfigData, tracker: LDAIConfigTracker, enabled: bool):
27-
self.config = config
20+
def __init__(self, tracker: LDAIConfigTracker, enabled: bool, model: Optional[dict], prompt: Optional[List[LDMessage]]):
2821
self.tracker = tracker
2922
self.enabled = enabled
23+
self.model = model
24+
self.prompt = prompt
3025

3126

3227
class LDAIClient:
@@ -73,14 +68,15 @@ def model_config(
7368

7469
enabled = variation.get('_ldMeta', {}).get('enabled', False)
7570
return AIConfig(
76-
config=AIConfigData(model=variation['model'], prompt=prompt),
7771
tracker=LDAIConfigTracker(
7872
self.client,
7973
variation.get('_ldMeta', {}).get('versionKey', ''),
8074
key,
8175
context,
8276
),
8377
enabled=bool(enabled),
78+
model=variation['model'],
79+
prompt=prompt
8480
)
8581

8682
def __interpolate_template(self, template: str, variables: Dict[str, Any]) -> str:

ldai/testing/test_model_config.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ldclient.integrations.test_data import TestData
44
from ldclient.testing.builders import *
55

6-
from ldai.client import AIConfig, AIConfigData, LDAIClient, LDMessage
6+
from ldai.client import AIConfig, LDAIClient, LDMessage
77
from ldai.tracker import LDAIConfigTracker
88

99

@@ -85,77 +85,67 @@ def ldai_client(client: LDClient) -> LDAIClient:
8585
def test_model_config_interpolation(ldai_client: LDAIClient, tracker):
8686
context = Context.create('user-key')
8787
default_value = AIConfig(
88-
config=AIConfigData(
89-
model={'modelId': 'fakeModel'},
90-
prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
91-
),
9288
tracker=tracker,
9389
enabled=True,
90+
model={'modelId': 'fakeModel'},
91+
prompt=[LDMessage(role='system', content='Hello, {{name}}!')],
9492
)
9593
variables = {'name': 'World'}
9694

9795
config = ldai_client.model_config('model-config', context, default_value, variables)
9896

99-
assert config.config.prompt is not None
100-
assert len(config.config.prompt) > 0
101-
assert config.config.prompt[0].content == 'Hello, World!'
97+
assert config.prompt is not None
98+
assert len(config.prompt) > 0
99+
assert config.prompt[0].content == 'Hello, World!'
102100
assert config.enabled is True
103101

104102

105103
def test_model_config_no_variables(ldai_client: LDAIClient, tracker):
106104
context = Context.create('user-key')
107-
default_value = AIConfig(
108-
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
109-
)
105+
default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[])
110106

111107
config = ldai_client.model_config('model-config', context, default_value, {})
112108

113-
assert config.config.prompt is not None
114-
assert len(config.config.prompt) > 0
115-
assert config.config.prompt[0].content == 'Hello, !'
109+
assert config.prompt is not None
110+
assert len(config.prompt) > 0
111+
assert config.prompt[0].content == 'Hello, !'
116112
assert config.enabled is True
117113

118114

119115
def test_context_interpolation(ldai_client: LDAIClient, tracker):
120116
context = Context.builder('user-key').name("Sandy").build()
121-
default_value = AIConfig(
122-
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
123-
)
117+
default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[])
124118
variables = {'name': 'World'}
125119

126120
config = ldai_client.model_config(
127121
'ctx-interpolation', context, default_value, variables
128122
)
129123

130-
assert config.config.prompt is not None
131-
assert len(config.config.prompt) > 0
132-
assert config.config.prompt[0].content == 'Hello, Sandy!'
124+
assert config.prompt is not None
125+
assert len(config.prompt) > 0
126+
assert config.prompt[0].content == 'Hello, Sandy!'
133127
assert config.enabled is True
134128

135129

136130
def test_model_config_multiple(ldai_client: LDAIClient, tracker):
137131
context = Context.create('user-key')
138-
default_value = AIConfig(
139-
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=True
140-
)
132+
default_value = AIConfig(tracker=tracker, enabled=True, model={}, prompt=[])
141133
variables = {'name': 'World', 'day': 'Monday'}
142134

143135
config = ldai_client.model_config(
144136
'multiple-prompt', context, default_value, variables
145137
)
146138

147-
assert config.config.prompt is not None
148-
assert len(config.config.prompt) > 0
149-
assert config.config.prompt[0].content == 'Hello, World!'
150-
assert config.config.prompt[1].content == 'The day is, Monday!'
139+
assert config.prompt is not None
140+
assert len(config.prompt) > 0
141+
assert config.prompt[0].content == 'Hello, World!'
142+
assert config.prompt[1].content == 'The day is, Monday!'
151143
assert config.enabled is True
152144

153145

154146
def test_model_config_disabled(ldai_client: LDAIClient, tracker):
155147
context = Context.create('user-key')
156-
default_value = AIConfig(
157-
config=AIConfigData(model={}, prompt=[]), tracker=tracker, enabled=False
158-
)
148+
default_value = AIConfig(tracker=tracker, enabled=False, model={}, prompt=[])
159149

160150
config = ldai_client.model_config('off-config', context, default_value, {})
161151

0 commit comments

Comments
 (0)