Skip to content

Commit ff15da7

Browse files
author
yuan.wang
committed
fix pref_string for memos online api
1 parent a47af83 commit ff15da7

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

evaluation/scripts/utils/client.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ def search(self, query, user_id, top_k):
234234
"user_id": user_id,
235235
"memory_limit_number": top_k,
236236
"mode": os.getenv("SEARCH_MODE", "fast"),
237+
"include_preference": True,
238+
"pref_top_k": 6,
237239
}
238240
)
239241

@@ -243,10 +245,23 @@ def search(self, query, user_id, top_k):
243245
response = requests.request("POST", url, data=payload, headers=self.headers)
244246
assert response.status_code == 200, response.text
245247
assert json.loads(response.text)["message"] == "ok", response.text
246-
res = json.loads(response.text)["data"]["memory_detail_list"]
247-
for i in res:
248+
text_mem_res = json.loads(response.text)["data"]["memory_detail_list"]
249+
pref_mem_res = json.loads(response.text)["data"]["preference_detail_list"]
250+
for i in text_mem_res:
248251
i.update({"memory": i.pop("memory_value")})
249-
return {"text_mem": [{"memories": res}], "pref_str": ""}
252+
253+
explicit_prefs = [p['preference'] for p in pref_mem_res if p.get('preference_type', '') == 'explicit_preference']
254+
implicit_prefs = [p['preference'] for p in pref_mem_res if p.get('preference_type', '') == 'implicit_preference']
255+
256+
pref_parts = []
257+
if explicit_prefs:
258+
pref_parts.append("Explicit Preference:\n" + "\n".join(f"{i + 1}. {p}" for i, p in enumerate(explicit_prefs)))
259+
if implicit_prefs:
260+
pref_parts.append("Implicit Preference:\n" + "\n".join(f"{i + 1}. {p}" for i, p in enumerate(implicit_prefs)))
261+
262+
pref_string = "\n".join(pref_parts)
263+
264+
return {"text_mem": [{"memories": text_mem_res}], "pref_string": pref_string}
250265
except Exception as e:
251266
if attempt < max_retries - 1:
252267
time.sleep(2**attempt)
@@ -336,19 +351,23 @@ def wait_for_completion(self, task_id):
336351

337352
if __name__ == "__main__":
338353
messages = [
339-
{"role": "user", "content": "杭州西湖有什么好玩的"},
340-
{"role": "assistant", "content": "杭州西湖有好多松鼠,还有断桥"},
354+
# {"role": "user", "content": "杭州西湖有什么好玩的,我喜欢动物"},
355+
# {"role": "assistant", "content": "杭州西湖有好多松鼠, 你喜欢动物的话可以去看松鼠"},
356+
{"role": "user", "content": "我暑假定好去广州旅游,住宿的话有哪些连锁酒店可选?"},
357+
{"role": "assistant", "content": "您可以考虑【七天、全季、希尔顿】等等"},
358+
{"role": "user", "content": "我选七天"},
359+
{"role": "assistant", "content": "好的,有其他问题再问我。"},
341360
]
342-
user_id = "test_user"
361+
user_id = "test_user2"
343362
iso_date = "2023-05-01T00:00:00.000Z"
344363
timestamp = 1682899200
345364
query = "杭州西湖有什么"
346365
top_k = 5
347366

348367
# MEMOS-API
349-
client = MemosApiClient()
368+
client = MemosApiOnlineClient()
350369
for m in messages:
351370
m["created_at"] = iso_date
352-
client.add(messages, user_id, user_id)
371+
# client.add(messages, user_id, user_id)
353372
memories = client.search(query, user_id, top_k)
354373
print(memories)

0 commit comments

Comments
 (0)