22from enum import IntEnum
33import os , sys , signal , queue
44import threading
5- import json
5+ import json , base64
66from typing import Any , Iterable , List , Union
77
88try :
@@ -37,6 +37,10 @@ class PrintType(IntEnum):
3737 PRINT_EVT_ASYNC_COMPLETED = 100 , # last async operation completed (utf8_str is null)
3838 PRINT_EVT_THOUGHT_COMPLETED = 101 , # thought completed
3939
class EmbeddingPurpose(IntEnum):
    """What a text embedding is computed for.

    The integer value is what `text_embedding` hands to the native
    embedding call as its `c_int` argument.
    """

    # Plain int literals: a trailing comma here would make each value a
    # 1-tuple that only becomes an int through IntEnum's argument unpacking.
    Document = 0  # for document
    Query = 1     # for query
43+
4044class LibChatLLM :
4145
4246 _obj2id = {}
@@ -100,11 +104,15 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str
100104 self ._chatllm_show_statistics = self ._lib .chatllm_show_statistics
101105 self ._chatllm_save_session = self ._lib .chatllm_save_session
102106 self ._chatllm_load_session = self ._lib .chatllm_load_session
107+ self ._chatllm_multimedia_msg_prepare = self ._lib .chatllm_multimedia_msg_prepare
108+ self ._chatllm_multimedia_msg_append = self ._lib .chatllm_multimedia_msg_append
109+ self ._chatllm_user_input_multimedia_msg = self ._lib .chatllm_user_input_multimedia_msg
103110
104111 self ._chatllm_async_user_input = self ._lib .chatllm_async_user_input
105112 self ._chatllm_async_ai_continue = self ._lib .chatllm_async_ai_continue
106113 self ._chatllm_async_tool_input = self ._lib .chatllm_async_tool_input
107114 self ._chatllm_async_tool_completion = self ._lib .chatllm_async_tool_completion
115+ self ._chatllm_async_user_input_multimedia_msg = self ._lib .chatllm_async_user_input_multimedia_msg
108116
109117 self ._chatllm_create .restype = c_void_p
110118 self ._chatllm_create .argtypes = []
@@ -123,11 +131,20 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str
123131 self ._chatllm_async_ai_continue .restype = c_int
124132 self ._chatllm_async_ai_continue .argtypes = [c_void_p , c_char_p ]
125133
134+ self ._chatllm_multimedia_msg_prepare .argtypes = [c_void_p ]
135+ self ._chatllm_multimedia_msg_append .restype = c_int
136+ self ._chatllm_multimedia_msg_append .argtypes = [c_void_p , c_char_p , c_char_p ]
137+
126138 self ._chatllm_user_input .restype = c_int
127139 self ._chatllm_user_input .argtypes = [c_void_p , c_char_p ]
128140 self ._chatllm_async_user_input .restype = c_int
129141 self ._chatllm_async_user_input .argtypes = [c_void_p , c_char_p ]
130142
143+ self ._chatllm_user_input_multimedia_msg .restype = c_int
144+ self ._chatllm_user_input_multimedia_msg .argtypes = [c_void_p ]
145+ self ._chatllm_async_user_input_multimedia_msg .restype = c_int
146+ self ._chatllm_async_user_input_multimedia_msg .argtypes = [c_void_p ]
147+
131148 self ._chatllm_tool_input .restype = c_int
132149 self ._chatllm_tool_input .argtypes = [c_void_p , c_char_p ]
133150 self ._chatllm_async_tool_input .restype = c_int
@@ -139,7 +156,7 @@ def __init__(self, lib: str = '', model_storage: str = '', init_params: list[str
139156 self ._chatllm_async_tool_completion .argtypes = [c_void_p , c_char_p ]
140157
141158 self ._chatllm_text_embedding .restype = c_int
142- self ._chatllm_text_embedding .argtypes = [c_void_p , c_char_p ]
159+ self ._chatllm_text_embedding .argtypes = [c_void_p , c_char_p , c_int ]
143160
144161 self ._chatllm_text_tokenize .restype = c_int
145162 self ._chatllm_text_tokenize .argtypes = [c_void_p , c_char_p ]
@@ -241,11 +258,44 @@ def start(self, obj: c_void_p, callback_obj: Any) -> int:
def set_ai_prefix(self, obj: c_void_p, prefix: str) -> int:
    """Pass `prefix` to the bound `_chatllm_set_ai_prefix` callable for `obj`.

    The prefix is encoded with `str.encode()` and wrapped in a `c_char_p`;
    whatever the bound callable returns is returned unchanged.
    """
    encoded_prefix = prefix.encode()
    native_prefix = c_char_p(encoded_prefix)
    return self._chatllm_set_ai_prefix(obj, native_prefix)
243260
244- def chat (self , obj : c_void_p , user_input : str ) -> int :
245- return self ._chatllm_user_input (obj , c_char_p (user_input .encode ()))
246-
247- def async_chat (self , obj : c_void_p , user_input : str ) -> int :
248- return self ._chatllm_async_user_input (obj , c_char_p (user_input .encode ()))
261+ def _input_multimedia_msg (self , obj : c_void_p , user_input : List [dict | str ]) -> int :
262+ self ._chatllm_multimedia_msg_prepare (obj )
263+ for x in user_input :
264+ if isinstance (x , str ):
265+ self ._chatllm_multimedia_msg_append (obj , c_char_p ('text' ), c_char_p (x ))
266+ elif isinstance (x , dict ):
267+ t = x ['type' ]
268+ if t == 'text' :
269+ data = x ['text' ].encode ()
270+ else :
271+ if 'file' in x :
272+ with open (x ['file' ], 'rb' ) as f :
273+ data = f .read ()
274+ elif 'url' in x :
275+ url : str = x ['url' ]
276+ if url .startswith ('data:' ):
277+ i = url .find ('base64,' )
278+ data = base64 .decodebytes (url [i + 7 :].encode ())
279+ else :
280+ data = model_downloader .download_file_to_bytes (url )
281+ else :
282+ raise Exception (f'unknown message piece: { x } ' )
283+ data = base64 .b64encode (data )
284+ self ._chatllm_multimedia_msg_append (obj , c_char_p (t .encode ()), c_char_p (data ))
285+
286+ def chat (self , obj : c_void_p , user_input : str | List [dict | str ]) -> int :
287+ if isinstance (user_input , str ):
288+ return self ._chatllm_user_input (obj , c_char_p (user_input .encode ()))
289+ elif isinstance (user_input , list ):
290+ self ._input_multimedia_msg (obj , user_input )
291+ return self ._chatllm_user_input_multimedia_msg (obj )
292+
293+ def async_chat (self , obj : c_void_p , user_input : str | List [dict | str ]) -> int :
294+ if isinstance (user_input , str ):
295+ return self ._chatllm_async_user_input (obj , c_char_p (user_input .encode ()))
296+ else :
297+ self ._input_multimedia_msg (obj , user_input )
298+ self ._chatllm_async_user_input_multimedia_msg (obj )
249299
def ai_continue(self, obj: c_void_p, suffix: str) -> int:
    """Pass `suffix` to the bound `_chatllm_ai_continue` callable for `obj`.

    The suffix is encoded with `str.encode()` and wrapped in a `c_char_p`;
    the callable's return value is passed back to the caller.
    """
    native_suffix = c_char_p(suffix.encode())
    status = self._chatllm_ai_continue(obj, native_suffix)
    return status
@@ -268,8 +318,8 @@ def tool_completion(self, obj: c_void_p, user_input: str) -> int:
def text_tokenize(self, obj: c_void_p, text: str) -> int:
    """Pass `text` to the bound `_chatllm_text_tokenize` callable for `obj`.

    Returns the integer produced by that callable (its `restype` is set to
    `c_int`), not a string -- the previous `-> str` annotation was wrong.
    NOTE(review): the tokens themselves are presumably delivered some other
    way (e.g. a print callback) -- confirm against the native API.
    """
    return self._chatllm_text_tokenize(obj, c_char_p(text.encode()))
270320
271- def text_embedding (self , obj : c_void_p , text : str ) -> str :
272- return self ._chatllm_text_embedding (obj , c_char_p (text .encode ()))
def text_embedding(self, obj: c_void_p, text: str, purpose: EmbeddingPurpose = EmbeddingPurpose.Document) -> int:
    """Request an embedding of `text` from the bound `_chatllm_text_embedding`.

    Args:
        obj: handle passed through to the native call.
        text: text to embed; encoded with `str.encode()`.
        purpose: an `EmbeddingPurpose` member, or a plain int with the same
            value; sent as the call's `c_int` argument.

    Returns:
        The integer produced by the native call (its `restype` is `c_int`);
        the previous `-> str` annotation was wrong.
    """
    # int() accepts both IntEnum members and bare ints, whereas `.value`
    # would raise AttributeError for a caller passing 0 or 1 directly.
    return self._chatllm_text_embedding(obj, c_char_p(text.encode()), c_int(int(purpose)))
273323
def qa_rank(self, obj: c_void_p, q: str, a: str) -> float:
    """Pass question `q` and answer `a` to the bound `_chatllm_qa_rank`.

    Both strings are encoded with `str.encode()` and wrapped in `c_char_p`;
    the callable's return value is handed back unchanged.
    """
    question = c_char_p(q.encode())
    answer = c_char_p(a.encode())
    return self._chatllm_qa_rank(obj, question, answer)
0 commit comments