Commit 68bd1fd

feat: add demo plot infer

1 parent: 7b0f805

20 files changed: +586 -22 lines

src/client/memobase/core/entry.py

Lines changed: 1 addition & 0 deletions

@@ -119,6 +119,7 @@ def insert(self, blob_data: Blob) -> str:
                 json=blob_data.to_request(),
             )
         )
+        print(r)
         return r.data["id"]

    def get(self, blob_id: str) -> Blob:
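For context, a minimal client-side sketch of the insert path this touches. MemoBaseClient, add_user, get_user, and ChatBlob follow the memobase client's documented usage; the URL and key are placeholders:

from memobase import MemoBaseClient, ChatBlob

client = MemoBaseClient(project_url="http://localhost:8019", api_key="secret")
user = client.get_user(client.add_user())
blob_id = user.insert(
    ChatBlob(messages=[{"role": "user", "content": "Hi, I love board games"}])
)
# With this commit, the wrapped HTTP response `r` is also printed before
# the blob id is returned -- a debug aid for the demo.
print(blob_id)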

src/server/api/api.py

Lines changed: 6 additions & 0 deletions

@@ -245,5 +245,11 @@ def custom_openapi():
 )(api_layer.context.get_user_context)
 
 
+router.post(
+    "/users/roleplay/proactive/{user_id}",
+    tags=["roleplay"],
+    # openapi_extra=API_X_CODE_DOCS["POST /users/roleplay/proactive/{user_id}"],
+)(api_layer.roleplay.infer_proactive_topics)
+
 app.include_router(router)
 app.add_middleware(api_layer.middleware.AuthMiddleware)
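A hedged sketch of calling the new route over HTTP. The /api/v1 prefix, the bearer-token header expected by AuthMiddleware, and the request body field names (messages, agent_context) are assumptions read off the handler shown later in this diff:

import requests

resp = requests.post(
    "http://localhost:8019/api/v1/users/roleplay/proactive/some-user-id",
    headers={"Authorization": "Bearer YOUR_PROJECT_TOKEN"},
    params={"topk": 5, "topic_limits_json": '{"interest": 3}'},  # optional query args
    json={
        "messages": [{"role": "user", "content": "I'm bored today"}],
        "agent_context": "You are a sci-fi roleplay partner.",
    },
)
print(resp.json())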

src/server/api/memobase_server/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-__version__ = "0.0.33"
+__version__ = "0.0.34"
 
 __author__ = "memobase.io"
 __url__ = "https://github.com/memodb-io/memobase"

src/server/api/memobase_server/api_layer/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,3 +7,4 @@
 from . import event
 from . import context
 from . import middleware
+from . import roleplay
src/server/api/memobase_server/api_layer/roleplay.py

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+import json
+from ..controllers import full as controllers
+from ..controllers.modal.roleplay import proactive_topics
+from ..models.blob import BlobType
+from ..models.utils import Promise, CODE
+from ..models import response as res
+from fastapi import Request
+from fastapi import Body, Path, Query
+
+
+async def infer_proactive_topics(
+    request: Request,
+    user_id: str = Path(..., description="The ID of the user"),
+    topk: int = Query(
+        None, description="Number of profiles to retrieve; default is all"
+    ),
+    max_token_size: int = Query(
+        None,
+        description="Max token size of returned profile content; default is all",
+    ),
+    prefer_topics: list[str] = Query(
+        None,
+        description="Rank preferred topics first to try to keep them during filtering; default order is by updated time",
+    ),
+    only_topics: list[str] = Query(
+        None,
+        description="Only return profiles with these topics; default is all",
+    ),
+    max_subtopic_size: int = Query(
+        None,
+        description="Max number of subtopics per topic in the returned profile; default is all",
+    ),
+    topic_limits_json: str = Query(
+        None,
+        description='Set per-topic subtopic limits in JSON, e.g. {"topic1": 3, "topic2": 5}. Limits in this param override `max_subtopic_size`.',
+    ),
+    body: res.ProactiveTopicRequest = Body(..., description="The body of the request"),
+) -> res.ProactiveTopicResponse:
+    """Infer proactive topics to steer the roleplay conversation"""
+    project_id = request.state.memobase_project_id
+    topic_limits_json = topic_limits_json or "{}"
+    try:
+        topic_limits = res.StrIntData(data=json.loads(topic_limits_json)).data
+    except Exception as e:
+        return Promise.reject(
+            CODE.BAD_REQUEST, f"Invalid JSON request: {e}"
+        ).to_response(res.ProactiveTopicResponse)
+    p = await proactive_topics.process_messages(
+        user_id,
+        project_id,
+        body.messages,
+        body.agent_context,
+        prefer_topics,
+        topk,
+        max_token_size,
+        only_topics,
+        max_subtopic_size,
+        topic_limits,
+    )
+    return p.to_response(res.ProactiveTopicResponse)
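The topic_limits_json handling above is worth a standalone look: a JSON-encoded query parameter is validated into dict[str, int] before use. A self-contained sketch with a stand-in for res.StrIntData (assumed to be a pydantic model with a data: dict[str, int] field):

import json
from pydantic import BaseModel

class StrIntData(BaseModel):  # stand-in for res.StrIntData, an assumption
    data: dict[str, int]

limits = StrIntData(data=json.loads('{"hobby": 3, "work": 5}')).data
assert limits == {"hobby": 3, "work": 5}

try:  # malformed input raises and is mapped to CODE.BAD_REQUEST in the handler
    StrIntData(data=json.loads("{not json"))
except Exception as e:
    print(f"Invalid JSON request: {e}")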

src/server/api/memobase_server/controllers/event.py

Lines changed: 1 addition & 4 deletions

@@ -28,10 +28,7 @@ async def get_user_events(
     ).filter(UserEvent.event_data.has_key("event_tip"))
     user_events = query.order_by(UserEvent.created_at.desc()).limit(topk).all()
     if user_events is None:
-        return Promise.reject(
-            CODE.NOT_FOUND,
-            f"No user events found for user {user_id}",
-        )
+        return Promise.resolve(UserEventsData(events=[]))
     results = [
         {
             "id": ue.id,
src/server/api/memobase_server/controllers/modal/roleplay/detect_interest.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+from ....models.utils import Promise, CODE
+from ....env import CONFIG, LOG, ProfileConfig
+from ....utils import get_encoded_tokens, truncate_string
+from ....llms import llm_complete
+from ....models.blob import OpenAICompatibleMessage
+from .types import PROMPTS, ChatInterest
+from ..utils import try_json_loads
+
+
+async def detect_chat_interest(
+    project_id: str,
+    messages: list[OpenAICompatibleMessage],
+    profile_config: ProfileConfig,
+) -> Promise[ChatInterest]:
+    USE_LANGUAGE = "zh"
+    prompt = PROMPTS[USE_LANGUAGE]["detect_interest"]
+
+    r = await llm_complete(
+        project_id,
+        prompt.get_input(messages),
+        system_prompt=prompt.get_prompt(),
+        temperature=0.2,  # precise
+        model=CONFIG.best_llm_model,
+        **prompt.get_kwargs(),
+    )
+    if not r.ok():
+        return r
+    content = r.data()
+    data = try_json_loads(content)
+    print(data)
+    if data is None:
+        return Promise.reject(
+            CODE.INTERNAL_SERVER_ERROR, "Unable to parse the LLM JSON response"
+        )
+    return Promise.resolve(data)
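try_json_loads is imported from ..utils but not shown in this commit; below is a minimal stand-in consistent with how it is used here (returns None on parse failure), plus the shape detect_chat_interest expects back from the LLM:

import json
from typing import Any

def try_json_loads(content: str) -> Any | None:  # stand-in, assumed behavior
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        return None

print(try_json_loads('{"status": "interested", "action": "new_topic"}'))
print(try_json_loads("not json"))  # None -> rejected as INTERNAL_SERVER_ERROR above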
src/server/api/memobase_server/controllers/modal/roleplay/predict_new_topics.py

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+import re
+from ....models.utils import Promise, CODE
+from ....env import CONFIG, LOG, ProfileConfig
+from ....utils import get_encoded_tokens, truncate_string
+from ....llms import llm_complete
+from ....models.blob import OpenAICompatibleMessage
+from ....models.response import UserStatusesData
+from .types import PROMPTS, InferPlot
+
+
+def extract_plot_output(content: str):
+    themes = re.search(r"<themes>(.*?)</themes>", content, re.DOTALL)
+    overview = re.search(r"<overview>(.*?)</overview>", content, re.DOTALL)
+    timeline = re.search(r"<timeline>(.*?)</timeline>", content, re.DOTALL)
+    return (
+        themes.group(1).strip() if themes else None,
+        overview.group(1).strip() if overview else None,
+        timeline.group(1).strip() if timeline else None,
+    )
+
+
+async def predict_new_topics(
+    project_id: str,
+    messages: list[OpenAICompatibleMessage],
+    latest_statuses: UserStatusesData,
+    user_context: str,
+    agent_context: str,
+    profile_config: ProfileConfig,
+    max_before_old_topics: int = 5,
+) -> Promise[InferPlot]:
+    USE_LANGUAGE = "zh"
+    prompt = PROMPTS[USE_LANGUAGE]["infer_plot"]
+
+    latest_plots = [
+        ld.attributes["new_topic"]["overview"]
+        for ld in latest_statuses.statuses
+        if "new_topic" in ld.attributes
+    ][:max_before_old_topics]
+    print(
+        "THINK",
+        prompt.get_input(agent_context, user_context, latest_plots, messages),
+    )
+    r = await llm_complete(
+        project_id,
+        prompt.get_input(agent_context, user_context, latest_plots, messages),
+        system_prompt=prompt.get_prompt(),
+        temperature=0.2,  # precise
+        model=CONFIG.thinking_llm_model,
+        **prompt.get_kwargs(),
+        no_cache=True,
+    )
+    if not r.ok():
+        return r
+    content = r.data()
+    print(content)
+    themes, overview, timeline = extract_plot_output(content)
+    return Promise.resolve(dict(themes=themes, overview=overview, timeline=timeline))
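A quick check of extract_plot_output (as defined above) against the tagged output format the prompt presumably asks for; the sample text is illustrative, not taken from the zh_infer_plot prompt file:

sample = """
<themes>found family, slow-burn mystery</themes>
<overview>The pair uncovers a letter hinting at a shared past.</overview>
<timeline>1. Find the letter 2. Argue over opening it 3. Read it together</timeline>
"""
themes, overview, timeline = extract_plot_output(sample)
assert themes == "found family, slow-burn mystery"
assert timeline.startswith("1. Find the letter")
# Missing tags degrade gracefully to None rather than raising:
assert extract_plot_output("no tags here") == (None, None, None)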
src/server/api/memobase_server/controllers/modal/roleplay/proactive_topics.py

Lines changed: 116 additions & 0 deletions

@@ -0,0 +1,116 @@
+from ....env import ContanstTable, CONFIG, LOG
+from ...status import append_user_status, get_user_statuses
+from ...profile import get_user_profiles, truncate_profiles
+from ...project import get_project_profile_config
+from ....models.blob import OpenAICompatibleMessage
+from ....models.utils import Promise
+from ....models.response import ProactiveTopicData
+from ...profile import get_user_profiles, truncate_profiles
+from .detect_interest import detect_chat_interest
+from .predict_new_topics import predict_new_topics
+
+# from .types import
+
+
+def pack_timeline_prompt(timeline: str, language: str) -> str:
+    if language == "zh":
+        return f"## 下面是你的剧本,如果我没有主动提供话题的话,参考下面剧情推动我们的对话:\n{timeline}##"
+    else:
+        return f"## Here is your script, if I don't provide a topic, please refer to the following plot to drive our conversation: \n{timeline}##"
+
+
+async def process_messages(
+    user_id: str,
+    project_id: str,
+    messages: list[OpenAICompatibleMessage],
+    agent_context: str = None,
+    prefer_topics: list[str] = None,
+    topk: int = None,
+    max_token_size: int = None,
+    only_topics: list[str] = None,
+    max_subtopic_size: int = None,
+    topic_limits: dict[str, int] = None,
+) -> Promise[ProactiveTopicData]:
+    p = await get_project_profile_config(project_id)
+    if not p.ok():
+        return p
+    project_profiles = p.data()
+    USE_LANGUAGE = "zh"
+    # USE_LANGUAGE = project_profiles.language or CONFIG.language
+
+    interest = await detect_chat_interest(
+        project_id,
+        messages,
+        profile_config=project_profiles,
+    )
+    if not interest.ok():
+        return interest
+    interest_data = interest.data()
+    # if interest_data["action"] != "new_topic":
+    #     await append_user_status(
+    #         user_id,
+    #         project_id,
+    #         ContanstTable.roleplay_plot_status,
+    #         {
+    #             "interest": interest_data,
+    #         },
+    #     )
+    #     return Promise.resolve(ProactiveTopicData(action="continue"))
+    latests_statuses = await get_user_statuses(
+        user_id, project_id, type=ContanstTable.roleplay_plot_status
+    )
+    if not latests_statuses.ok():
+        return latests_statuses
+    latests_statuses_data = latests_statuses.data()
+
+    p = await get_user_profiles(user_id, project_id)
+    if not p.ok():
+        return p
+    p = await truncate_profiles(
+        p.data(),
+        prefer_topics=prefer_topics,
+        topk=topk,
+        max_token_size=max_token_size,
+        only_topics=only_topics,
+        max_subtopic_size=max_subtopic_size,
+        topic_limits=topic_limits,
+    )
+    if not p.ok():
+        return p
+    user_profiles_data = p.data()
+    use_user_profiles = user_profiles_data.profiles
+    user_context = "\n".join(
+        [
+            f"{p.attributes.get('topic')}::{p.attributes.get('sub_topic')}: {p.content}"
+            for p in use_user_profiles
+        ]
+    )
+
+    p = await predict_new_topics(
+        project_id,
+        messages,
+        latests_statuses_data,
+        user_context,
+        agent_context,
+        project_profiles,
+    )
+    if not p.ok():
+        return p
+    plot = p.data()
+    await append_user_status(
+        user_id,
+        project_id,
+        ContanstTable.roleplay_plot_status,
+        {
+            "interest": interest_data,
+            "new_topic": plot,
+            "chats": [m.model_dump() for m in messages],
+        },
+    )
+
+    return Promise.resolve(
+        ProactiveTopicData(
+            action="new_topic",
+            topic_prompt=pack_timeline_prompt(plot["timeline"], USE_LANGUAGE),
+        )
+    )
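A toy walkthrough of the two string-building steps in process_messages, using pack_timeline_prompt from above; the profile rows are invented for illustration:

profiles = [
    {"topic": "interest", "sub_topic": "games", "content": "Plays strategy games"},
    {"topic": "basic_info", "sub_topic": "mood", "content": "Gets bored on weekdays"},
]
user_context = "\n".join(
    f"{p['topic']}::{p['sub_topic']}: {p['content']}" for p in profiles
)
print(user_context)
# interest::games: Plays strategy games
# basic_info::mood: Gets bored on weekdays

print(pack_timeline_prompt("1. Tease a secret 2. Reveal the clue", "en"))
# ## Here is your script, ... refer to the following plot to drive our conversation:
# 1. Tease a secret 2. Reveal the clue##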
src/server/api/memobase_server/controllers/modal/roleplay/types.py

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+from typing import TypedDict
+from ....prompts.roleplay import zh_detect_interest, zh_infer_plot
+
+ChatInterest = TypedDict("ChatInterest", {"status": str, "action": str})
+
+InferPlot = TypedDict(
+    "InferPlot", {"themes": str | None, "overview": str | None, "timeline": str | None}
+)
+
+PROMPTS = {
+    "en": {},
+    "zh": {
+        "detect_interest": zh_detect_interest,
+        "infer_plot": zh_infer_plot,
+    },
+}
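These TypedDicts only document the JSON shapes the LLM is asked to emit; there is no runtime validation. Toy values that type-check (note the str | None unions require Python 3.10+):

interest: ChatInterest = {"status": "interested", "action": "new_topic"}
plot: InferPlot = {
    "themes": "reunion",
    "overview": "Old friends meet again",
    "timeline": None,  # extract_plot_output yields None for a missing tag
}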
