@@ -234,6 +234,8 @@ def search(self, query, user_id, top_k):
234234 "user_id" : user_id ,
235235 "memory_limit_number" : top_k ,
236236 "mode" : os .getenv ("SEARCH_MODE" , "fast" ),
237+ "include_preference" : True ,
238+ "pref_top_k" : 6 ,
237239 }
238240 )
239241
@@ -243,10 +245,23 @@ def search(self, query, user_id, top_k):
243245 response = requests .request ("POST" , url , data = payload , headers = self .headers )
244246 assert response .status_code == 200 , response .text
245247 assert json .loads (response .text )["message" ] == "ok" , response .text
246- res = json .loads (response .text )["data" ]["memory_detail_list" ]
247- for i in res :
248+ text_mem_res = json .loads (response .text )["data" ]["memory_detail_list" ]
249+ pref_mem_res = json .loads (response .text )["data" ]["preference_detail_list" ]
250+ for i in text_mem_res :
248251 i .update ({"memory" : i .pop ("memory_value" )})
249- return {"text_mem" : [{"memories" : res }], "pref_str" : "" }
252+
253+ explicit_prefs = [p ['preference' ] for p in pref_mem_res if p .get ('preference_type' , '' ) == 'explicit_preference' ]
254+ implicit_prefs = [p ['preference' ] for p in pref_mem_res if p .get ('preference_type' , '' ) == 'implicit_preference' ]
255+
256+ pref_parts = []
257+ if explicit_prefs :
258+ pref_parts .append ("Explicit Preference:\n " + "\n " .join (f"{ i + 1 } . { p } " for i , p in enumerate (explicit_prefs )))
259+ if implicit_prefs :
260+ pref_parts .append ("Implicit Preference:\n " + "\n " .join (f"{ i + 1 } . { p } " for i , p in enumerate (implicit_prefs )))
261+
262+ pref_string = "\n " .join (pref_parts )
263+
264+ return {"text_mem" : [{"memories" : text_mem_res }], "pref_string" : pref_string }
250265 except Exception as e :
251266 if attempt < max_retries - 1 :
252267 time .sleep (2 ** attempt )
@@ -336,19 +351,23 @@ def wait_for_completion(self, task_id):
336351
337352if __name__ == "__main__" :
338353 messages = [
339- {"role" : "user" , "content" : "杭州西湖有什么好玩的" },
340- {"role" : "assistant" , "content" : "杭州西湖有好多松鼠,还有断桥" },
354+ # {"role": "user", "content": "杭州西湖有什么好玩的,我喜欢动物"},
355+ # {"role": "assistant", "content": "杭州西湖有好多松鼠, 你喜欢动物的话可以去看松鼠"},
356+ {"role" : "user" , "content" : "我暑假定好去广州旅游,住宿的话有哪些连锁酒店可选?" },
357+ {"role" : "assistant" , "content" : "您可以考虑【七天、全季、希尔顿】等等" },
358+ {"role" : "user" , "content" : "我选七天" },
359+ {"role" : "assistant" , "content" : "好的,有其他问题再问我。" },
341360 ]
342- user_id = "test_user "
361+ user_id = "test_user2 "
343362 iso_date = "2023-05-01T00:00:00.000Z"
344363 timestamp = 1682899200
345364 query = "杭州西湖有什么"
346365 top_k = 5
347366
348367 # MEMOS-API
349- client = MemosApiClient ()
368+ client = MemosApiOnlineClient ()
350369 for m in messages :
351370 m ["created_at" ] = iso_date
352- client .add (messages , user_id , user_id )
371+ # client.add(messages, user_id, user_id)
353372 memories = client .search (query , user_id , top_k )
354373 print (memories )
0 commit comments