Skip to content

Commit c8bffca

Browse files
committed
prompt_template is mind's param
1 parent 647fced commit c8bffca

File tree

4 files changed

+34
-21
lines changed

4 files changed

+34
-21
lines changed

examples/base_usage.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,17 @@
3535

3636
# or separately
3737
datasource = client.datasources.create(postgres_config)
38-
mind2 = client.minds.create(name='mind_name', datasources=[datasource] )
38+
mind = client.minds.create(name='mind_name', datasources=[datasource] )
39+
40+
# with prompt template
41+
mind = client.minds.create(name='mind_name', prompt_template='You are codding assistant')
3942

4043
# or add to existed mind
41-
mind3 = client.minds.create(name='mind_name')
44+
mind = client.minds.create(name='mind_name')
4245
# by config
43-
mind2.add_datasource(postgres_config)
46+
mind.add_datasource(postgres_config)
4447
# or by datasource
45-
mind2.add_datasource(datasource)
48+
mind.add_datasource(datasource)
4649

4750

4851
# --- managing minds ---

minds/minds.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ def __init__(
2828
self.name = name
2929
self.model_name = model_name
3030
self.provider = provider
31+
if parameters is None:
32+
parameters = {}
33+
self.prompt_template = parameters.pop('prompt_template', None)
3134
self.parameters = parameters
3235
self.created_at = created_at
3336
self.updated_at = updated_at
@@ -39,8 +42,9 @@ def update(
3942
name: str = None,
4043
model_name: str = None,
4144
provider=None,
45+
prompt_template=None,
46+
datasources=None,
4247
parameters=None,
43-
datasources=None
4448
):
4549
data = {}
4650

@@ -57,8 +61,13 @@ def update(
5761
data['model_name'] = model_name
5862
if provider is not None:
5963
data['provider'] = provider
60-
if parameters is not None:
61-
data['parameters'] = parameters
64+
if parameters is None:
65+
parameters = {}
66+
67+
data['parameters'] = parameters
68+
69+
if prompt_template is not None:
70+
data['parameters']['prompt_template'] = prompt_template
6271

6372
self.api.patch(
6473
f'/projects/{self.project}/minds/{self.name}',
@@ -185,8 +194,9 @@ def create(
185194
self, name,
186195
model_name=None,
187196
provider=None,
188-
parameters=None,
197+
prompt_template=None,
189198
datasources=None,
199+
parameters=None,
190200
replace=False,
191201
) -> Mind:
192202

@@ -206,6 +216,9 @@ def create(
206216

207217
if parameters is None:
208218
parameters = {}
219+
220+
if prompt_template is not None:
221+
parameters['prompt_template'] = prompt_template
209222
if 'prompt_template' not in parameters:
210223
parameters['prompt_template'] = DEFAULT_PROMPT_TEMPLATE
211224

tests/integration/test_base_flow.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,13 @@ def test_minds():
8888
mind_name,
8989
replace=True,
9090
datasources=[ds.name, ds2_cfg],
91-
parameters={
92-
'prompt_template': prompt1
93-
}
91+
prompt_template=prompt1
9492
)
9593

9694
# get
9795
mind = client.minds.get(mind_name)
9896
assert len(mind.datasources) == 2
99-
assert mind.parameters['prompt_template'] == prompt1
97+
assert mind.prompt_template == prompt1
10098

10199
# list
102100
mind_list = client.minds.list()
@@ -106,17 +104,15 @@ def test_minds():
106104
mind.update(
107105
name=mind_name2,
108106
datasources=[ds.name],
109-
parameters={
110-
'prompt_template': prompt2
111-
}
107+
prompt_template=prompt2
112108
)
113109
with pytest.raises(ObjectNotFound):
114110
# this name not exists
115111
client.minds.get(mind_name)
116112

117113
mind = client.minds.get(mind_name2)
118114
assert len(mind.datasources) == 1
119-
assert mind.parameters['prompt_template'] == prompt2
115+
assert mind.prompt_template == prompt2
120116

121117
# add datasource
122118
mind.add_datasource(ds2_cfg)

tests/unit/test_unit.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,14 @@ def test_create(self, mock_del, mock_post, mock_get):
121121
client = get_client()
122122

123123
mind_name = 'test_mind'
124-
parameters = {'prompt_template': 'always agree'}
124+
prompt_template = 'always agree'
125125
datasources = ['my_ds']
126126
provider = 'openai'
127127

128128
response_mock(mock_get, self.mind_json)
129129
create_params = {
130130
'name': mind_name,
131-
'parameters': parameters,
131+
'prompt_template': prompt_template,
132132
'datasources': datasources
133133
}
134134
mind = client.minds.create(**create_params)
@@ -137,8 +137,9 @@ def check_mind_created(mind, mock_post, create_params):
137137
args, kwargs = mock_post.call_args
138138
assert args[0].endswith('/api/projects/mindsdb/minds')
139139
request = kwargs['json']
140-
for k, v in create_params.items():
141-
assert request[k] == v
140+
for key in ('name', 'datasources', 'provider', 'model_name'),:
141+
assert request.get(key) == create_params.get(key)
142+
assert create_params.get('prompt_template') == request.get('parameters', {}).get('prompt_template')
142143

143144
self.compare_mind(mind, self.mind_json)
144145

@@ -147,7 +148,7 @@ def check_mind_created(mind, mock_post, create_params):
147148
# with replace
148149
create_params = {
149150
'name': mind_name,
150-
'parameters': parameters,
151+
'prompt_template': prompt_template,
151152
'provider': provider,
152153
}
153154
mind = client.minds.create(replace=True, **create_params)

0 commit comments

Comments
 (0)