Skip to content

Commit 3b949ba

Browse files
jameszyaoSimsonW
authored andcommitted
feat: add assistant memory
1. add memory to assistant schema and related APIs 2. add description and operationId to action schemas
1 parent 066d8e3 commit 3b949ba

File tree

12 files changed

+210
-33
lines changed

12 files changed

+210
-33
lines changed

examples/assistant/chat_with_assistant.ipynb

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@
6161
" \"paths\": {\n",
6262
" \"/{number}\": {\n",
6363
" \"get\": {\n",
64-
" \"summary\": \"Get fact about a number\",\n",
64+
" \"description\": \"Get a fact about a number\",\n",
65+
" \"operationId\": \"getNumberFact\",\n",
6566
" \"parameters\": [\n",
6667
" {\n",
6768
" \"name\": \"number\",\n",
@@ -109,6 +110,7 @@
109110
"outputs": [],
110111
"source": [
111112
"from taskingai.assistant import Assistant, Chat, AssistantTool, AssistantToolType\n",
113+
"from taskingai.assistant.memory import AssistantMessageWindowMemory\n",
112114
"\n",
113115
"# choose an available chat_completion model from your project\n",
114116
"model_id = \"YOUR_MODEL_ID\"\n",
@@ -117,6 +119,10 @@
117119
" model_id=model_id,\n",
118120
" name=\"My Assistant\",\n",
119121
" description=\"A assistant who knows the meaning of various numbers.\",\n",
122+
" memory=AssistantMessageWindowMemory(\n",
123+
" max_messages=20,\n",
124+
" max_tokens=1000\n",
125+
" ),\n",
120126
" system_prompt_template=[\n",
121127
" \"You know the meaning of various numbers.\",\n",
122128
" \"No matter what the user's language is, you will use the {{langugae}} to explain.\"\n",
@@ -255,8 +261,7 @@
255261
" chat_id=chat.chat_id,\n",
256262
")\n",
257263
"for message in messages:\n",
258-
" print(f\"{message.role}: {message.content.text}\")\n",
259-
" print(\"-\"*100)"
264+
" print(f\"{message.role}: {message.content.text}\")"
260265
],
261266
"metadata": {
262267
"collapsed": false
@@ -265,7 +270,7 @@
265270
},
266271
{
267272
"cell_type": "code",
268-
"execution_count": null,
273+
"execution_count": 15,
269274
"outputs": [],
270275
"source": [
271276
"# delete assistant\n",
@@ -277,16 +282,6 @@
277282
"collapsed": false
278283
},
279284
"id": "ed39836bbfdc7a4e"
280-
},
281-
{
282-
"cell_type": "code",
283-
"execution_count": null,
284-
"outputs": [],
285-
"source": [],
286-
"metadata": {
287-
"collapsed": false
288-
},
289-
"id": "325965f154e14f42"
290285
}
291286
],
292287
"metadata": {

examples/crud/assistant_crud.ipynb

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"outputs": [],
2727
"source": [
2828
"from taskingai.assistant import Assistant, Chat\n",
29+
"from taskingai.assistant.memory import AssistantNaiveMemory\n",
2930
"\n",
3031
"# choose an available chat_completion model from your project\n",
3132
"model_id = \"YOUR_MODEL_ID\""
@@ -53,6 +54,7 @@
5354
" name=\"My Assistant\",\n",
5455
" description=\"This is my assistant\",\n",
5556
" system_prompt_template=[\"You are a professional assistant speaking {{language}}.\"],\n",
57+
" memory=AssistantNaiveMemory(),\n",
5658
" tools=[],\n",
5759
" retrievals=[],\n",
5860
" metadata={\"foo\": \"bar\"},\n",
@@ -222,6 +224,18 @@
222224
"metadata": {
223225
"collapsed": false
224226
}
227+
},
228+
{
229+
"cell_type": "code",
230+
"execution_count": null,
231+
"outputs": [],
232+
"source": [
233+
"# delete assistant\n",
234+
"taskingai.assistant.delete_assistant(assistant_id=assistant.assistant_id)"
235+
],
236+
"metadata": {
237+
"collapsed": false
238+
}
225239
}
226240
],
227241
"metadata": {

examples/crud/retrieval_crud.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,10 @@
161161
"source": [
162162
"# create a new collection\n",
163163
"collection: Collection = create_collection()\n",
164+
"print(collection)\n",
164165
"\n",
165166
"# wait for the collection creation to finish\n",
166-
"time.sleep(10)"
167+
"time.sleep(3)"
167168
],
168169
"metadata": {
169170
"collapsed": false

examples/crud/tool_crud.ipynb

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@
6969
" \"paths\": {\n",
7070
" \"/{number}\": {\n",
7171
" \"get\": {\n",
72-
" \"summary\": \"Get fact about a number\",\n",
72+
" \"description\": \"Get fact about a number\",\n",
73+
" \"operationId\": \"getNumberFact\",\n",
7374
" \"parameters\": [\n",
7475
" {\n",
7576
" \"name\": \"number\",\n",
@@ -197,14 +198,6 @@
197198
},
198199
"id": "5a1a36d15055918f"
199200
},
200-
{
201-
"cell_type": "markdown",
202-
"source": [],
203-
"metadata": {
204-
"collapsed": false
205-
},
206-
"id": "2f288bf5d1988887"
207-
},
208201
{
209202
"cell_type": "markdown",
210203
"source": [],

taskingai/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
__title__ = "taskingai"
2-
__version__ = "0.0.2"
2+
__version__ = "0.0.5"

taskingai/assistant/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .assistant import *
22
from .chat import *
3-
from .message import *
3+
from .message import *

taskingai/assistant/assistant.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
11
from typing import Optional, List, Dict
22

33
from taskingai.client.utils import get_api_instance, ModuleType
4-
from taskingai.client.models import Assistant, AssistantRetrieval, AssistantTool, AssistantToolType, AssistantRetrievalType
5-
from taskingai.client.models import AssistantCreateRequest, AssistantCreateResponse,\
6-
AssistantUpdateRequest, AssistantUpdateResponse,\
7-
AssistantGetResponse, AssistantListResponse
4+
from taskingai.client.models import (
5+
Assistant,
6+
AssistantMemory,
7+
AssistantRetrieval,
8+
AssistantTool,
9+
AssistantToolType,
10+
AssistantRetrievalType
11+
)
12+
13+
from taskingai.client.models import (
14+
AssistantCreateRequest,
15+
AssistantCreateResponse,
16+
AssistantUpdateRequest,
17+
AssistantUpdateResponse,
18+
AssistantGetResponse,
19+
AssistantListResponse,
20+
)
821

922
__all__ = [
1023
"Assistant",
@@ -121,6 +134,7 @@ async def a_get_assistant(assistant_id: str) -> Assistant:
121134

122135
def create_assistant(
123136
model_id: str,
137+
memory: AssistantMemory,
124138
name: Optional[str] = None,
125139
description: Optional[str] = None,
126140
system_prompt_template: Optional[List[str]] = None,
@@ -130,8 +144,9 @@ def create_assistant(
130144
) -> Assistant:
131145
"""
132146
Create an assistant.
133-
147+
134148
:param model_id: The ID of an available chat completion model in your project.
149+
:param memory: The assistant memory.
135150
:param name: The assistant name.
136151
:param description: The assistant description.
137152
:param system_prompt_template: A list of system prompt chunks where prompt variables are wrapped by curly brackets, e.g. {{variable}}.
@@ -142,10 +157,12 @@ def create_assistant(
142157
"""
143158

144159
api_instance = get_api_instance(ModuleType.assistant)
160+
memory_dict = memory.model_dump()
145161
body = AssistantCreateRequest(
146162
model_id=model_id,
147163
name=name,
148164
description=description,
165+
memory=memory_dict,
149166
system_prompt_template=system_prompt_template,
150167
tools=tools,
151168
retrievals=retrievals,
@@ -158,6 +175,7 @@ def create_assistant(
158175

159176
async def a_create_assistant(
160177
model_id: str,
178+
memory: AssistantMemory,
161179
name: Optional[str] = None,
162180
description: Optional[str] = None,
163181
system_prompt_template: Optional[List[str]] = None,
@@ -169,6 +187,7 @@ async def a_create_assistant(
169187
Create an assistant in async mode.
170188
171189
:param model_id: The ID of an available chat completion model in your project.
190+
:param memory: The assistant memory.
172191
:param name: The assistant name.
173192
:param description: The assistant description.
174193
:param system_prompt_template: A list of system prompt chunks where prompt variables are wrapped by curly brackets, e.g. {{variable}}.
@@ -179,10 +198,12 @@ async def a_create_assistant(
179198
"""
180199

181200
api_instance = get_api_instance(ModuleType.assistant, async_client=True)
201+
memory_dict = memory.model_dump()
182202
body = AssistantCreateRequest(
183203
model_id=model_id,
184204
name=name,
185205
description=description,
206+
memory=memory_dict,
186207
system_prompt_template=system_prompt_template,
187208
tools=tools,
188209
retrievals=retrievals,
@@ -199,6 +220,7 @@ def update_assistant(
199220
name: Optional[str] = None,
200221
description: Optional[str] = None,
201222
system_prompt_template: Optional[List[str]] = None,
223+
memory: Optional[AssistantMemory] = None,
202224
tools: Optional[List[AssistantTool]] = None,
203225
retrievals: Optional[List[AssistantRetrieval]] = None,
204226
metadata: Optional[Dict[str, str]] = None,
@@ -211,6 +233,7 @@ def update_assistant(
211233
:param name: The assistant name.
212234
:param description: The assistant description.
213235
:param system_prompt_template: A list of system prompt chunks where prompt variables are wrapped by curly brackets, e.g. {{variable}}.
236+
:param memory: The assistant memory.
214237
:param tools: The assistant tools.
215238
:param retrievals: The assistant retrievals.
216239
:param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512.
@@ -223,6 +246,7 @@ def update_assistant(
223246
name=name,
224247
description=description,
225248
system_prompt_template=system_prompt_template,
249+
memory=memory,
226250
tools=tools,
227251
retrievals=retrievals,
228252
metadata=metadata,
@@ -238,6 +262,7 @@ async def a_update_assistant(
238262
name: Optional[str] = None,
239263
description: Optional[str] = None,
240264
system_prompt_template: Optional[List[str]] = None,
265+
memory: Optional[AssistantMemory] = None,
241266
tools: Optional[List[AssistantTool]] = None,
242267
retrievals: Optional[List[AssistantRetrieval]] = None,
243268
metadata: Optional[Dict[str, str]] = None,
@@ -250,6 +275,7 @@ async def a_update_assistant(
250275
:param name: The assistant name.
251276
:param description: The assistant description.
252277
:param system_prompt_template: A list of system prompt chunks where prompt variables are wrapped by curly brackets, e.g. {{variable}}.
278+
:param memory: The assistant memory.
253279
:param tools: The assistant tools.
254280
:param retrievals: The assistant retrievals.
255281
:param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512.
@@ -262,6 +288,7 @@ async def a_update_assistant(
262288
name=name,
263289
description=description,
264290
system_prompt_template=system_prompt_template,
291+
memory=memory,
265292
tools=tools,
266293
retrievals=retrievals,
267294
metadata=metadata,

taskingai/assistant/memory.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from taskingai.client.models import (
2+
AssistantMemoryType,
3+
AssistantMemory,
4+
AssistantNaiveMemory,
5+
AssistantZeroMemory,
6+
AssistantMessageWindowMemory,
7+
)
8+
9+
__all__ = [
10+
"AssistantMemory",
11+
"AssistantMemoryType",
12+
"AssistantNaiveMemory",
13+
"AssistantZeroMemory",
14+
"AssistantMessageWindowMemory",
15+
]

taskingai/client/models/entity/assistant/assistant.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
from typing import List, Dict
2+
from pydantic import field_validator
23
from .._base import TaskingaiBaseModel
4+
from .assistant_memory import *
35
from enum import Enum
46

57
__all__ = [
68
"Assistant",
9+
"AssistantMemoryType",
10+
"AssistantNaiveMemory",
11+
"AssistantZeroMemory",
12+
"AssistantMessageWindowMemory",
713
"AssistantTool",
814
"AssistantRetrieval",
915
"AssistantToolType",
1016
"AssistantRetrievalType",
17+
"AssistantMemory"
1118
]
1219

1320

@@ -37,7 +44,16 @@ class Assistant(TaskingaiBaseModel):
3744
name: str
3845
description: str
3946
system_prompt_template: List[str]
47+
memory: AssistantMemory
4048
tools: List[AssistantTool]
4149
retrievals: List[AssistantRetrieval]
4250
metadata: Dict[str, str]
4351
created_timestamp: int
52+
53+
54+
@field_validator('memory', mode='before')
55+
def validate_memory(cls, memory_dict: Dict):
56+
memory: AssistantMemory = build_assistant_memory(memory_dict)
57+
if not memory:
58+
raise ValueError('Invalid input memory')
59+
return memory
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from enum import Enum
2+
from pydantic import BaseModel, Field
3+
from abc import ABC
4+
from typing import Optional, Dict
5+
6+
7+
class AssistantMemoryType(str, Enum):
8+
zero = "zero"
9+
naive = "naive"
10+
message_window = "message_window"
11+
12+
13+
class AssistantMemory(BaseModel, ABC):
14+
type: AssistantMemoryType = Field(..., description="The type of the memory.")
15+
16+
17+
class AssistantMessageWindowMemory(AssistantMemory):
18+
type: AssistantMemoryType = Field(AssistantMemoryType.message_window, Literal=AssistantMemoryType.message_window)
19+
max_messages: int = Field(...)
20+
max_tokens: int = Field(...)
21+
22+
23+
class AssistantNaiveMemory(AssistantMemory):
24+
type: AssistantMemoryType = Field(AssistantMemoryType.naive, Literal=AssistantMemoryType.naive)
25+
26+
27+
class AssistantZeroMemory(AssistantMemory):
28+
type: AssistantMemoryType = Field(AssistantMemoryType.zero, Literal=AssistantMemoryType.zero)
29+
30+
31+
def build_assistant_memory(memory_dict: Dict) -> Optional[AssistantMemory]:
32+
# Check if the memory dictionary is provided and has the 'type' key
33+
if not memory_dict or 'type' not in memory_dict:
34+
return None
35+
36+
memory_type = memory_dict['type']
37+
38+
if memory_type == AssistantMemoryType.zero.value:
39+
# For zero memory, no additional information is needed
40+
return AssistantZeroMemory()
41+
42+
elif memory_type == AssistantMemoryType.naive.value:
43+
# For naive memory, no additional configuration is needed
44+
return AssistantNaiveMemory()
45+
46+
elif memory_type == AssistantMemoryType.message_window.value:
47+
# For message window memory, additional configuration is needed
48+
max_messages = memory_dict.get('max_messages')
49+
max_tokens = memory_dict.get('max_tokens')
50+
51+
# Validate that required fields are present
52+
if max_messages is None or max_tokens is None:
53+
return None
54+
55+
return AssistantMessageWindowMemory(max_messages=max_messages, max_tokens=max_tokens)
56+
57+
else:
58+
# If the memory type is unknown, return None
59+
return None
60+
61+

0 commit comments

Comments
 (0)