1+ from io import BytesIO
2+ import base64
3+ import threading
4+ from concurrent .futures import ThreadPoolExecutor , Future
5+ import numpy as np # Import numpy
6+ from typing import Union
7+ HAS_IMPORTS = True
8+ try :
9+ from PIL import Image
10+ import numpy as np
11+
12+ except ImportError :
13+ HAS_IMPORTS = False
14+
15+ try :
16+ import requests
17+ from requests .adapters import HTTPAdapter , Retry
18+ except ImportError :
19+ HAS_IMPORTS = False
20+
class InfinityVisionAPI:
    """Thread-pooled HTTP client for an embeddings service.

    The client talks to two endpoints below ``url``: ``/embeddings`` (POST)
    and ``/health`` (GET).  Requests are executed on a background
    ``ThreadPoolExecutor`` and handed back to the caller as futures.

    Attributes:
        url: Base URL of the service (no trailing slash expected).
        format: Value sent as ``encoding_format``; ``"base64"`` responses are
            decoded locally, anything else is passed to ``np.array``.
        hidden_dim: Last dimension of the embedding returned by the probe
            request issued in ``__init__``; used to reshape base64 payloads.
        tp: Executor running the HTTP requests.
        sem: Semaphore bounding the number of queued + running requests.
        session: ``requests.Session`` with retrying adapters mounted.
    """

    # Modes Pillow's JPEG writer accepts as-is; anything else (RGBA, P, LA,
    # ...) has to be converted before saving.
    _JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(
        self,
        url: str = "https://infinity-multimodal.modal.michaelfeil.eu",
        format: str = "base64",
    ) -> None:
        """Create the session, probe the embedding width and start the pool.

        Args:
            url: Base URL of the embeddings service.
            format: ``encoding_format`` requested from the service.

        Raises:
            requests.HTTPError: If the probe request returns an error status.
        """
        self.url = url
        self.format = format

        # Build the retrying session first so the probe below uses the same
        # retry policy and connection pool as every later request.
        self.session = requests.Session()
        adapter = HTTPAdapter(max_retries=Retry(total=10, backoff_factor=0.5))
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

        probe = self.session.post(
            url + "/embeddings",
            json={  # get shape of output
                "model": "michaelfeil/colqwen2-v0.1",
                "input": ["test"],
                "encoding_format": "float",
                "modality": "text",
            },
        )
        probe.raise_for_status()
        self.hidden_dim = np.array(probe.json()["data"][0]["embedding"]).shape[-1]

        # Created only after the probe succeeded, so a failing constructor
        # leaves no executor behind.
        self.tp = ThreadPoolExecutor()
        self.sem = threading.Semaphore(64)

    def close(self) -> None:
        """Wait for outstanding requests, then release the pool and session."""
        self.tp.shutdown(wait=True)
        self.session.close()

    def __enter__(self) -> "InfinityVisionAPI":
        return self

    def __exit__(self, *exc_info) -> None:
        self.close()

    def _image_payload(self, images: list["Image.Image"]) -> list[str]:
        """Encode images as ``data:image/jpeg;base64,...`` URIs.

        Images whose mode JPEG cannot store are converted to RGB first, which
        discards any alpha channel.

        Raises:
            ImportError: If the optional imports at module level failed.
            ValueError: If an item has no ``save`` method.
        """
        if not HAS_IMPORTS:
            raise ImportError("PIL is required to use this class")
        data_uris = []
        for image in images:
            buffered = BytesIO()
            if not hasattr(image, "save"):
                raise ValueError("Image must be a PIL Image")
            # Objects without a ``mode`` attribute are saved unchanged.
            if getattr(image, "mode", "RGB") not in self._JPEG_MODES:
                image = image.convert("RGB")
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            data_uris.append(f"data:image/jpeg;base64,{img_str}")
        return data_uris

    def _text_payload(self, texts: list[str]) -> list[str]:
        """Return ``texts`` unchanged; text needs no encoding."""
        return texts

    def health(self) -> bool:
        """GET ``/health``; raise on an error status, else report ``== 200``."""
        response = self.session.get(f"{self.url}/health")
        response.raise_for_status()
        return response.status_code == 200

    def _request(
        self, model: str, images_or_text: list[Union["Image.Image", str]]
    ) -> tuple:
        """POST one batch to ``/embeddings`` and decode the response.

        Args:
            model: Model name forwarded to the service.
            images_or_text: Either all image objects (anything with a
                ``save`` method) or all strings.

        Returns:
            ``(embeddings, total_tokens)`` where ``embeddings`` is a list with
            one numpy array per entry of the response's ``data`` field and
            ``total_tokens`` is ``usage.total_tokens`` from the response.

        Raises:
            ValueError: If the batch is empty or mixes images and text.
            requests.HTTPError: If the service returns an error status.
        """
        # ``all()`` is vacuously true for an empty list, which would otherwise
        # send an empty batch labelled as images.
        if not images_or_text:
            raise ValueError("images_or_text must not be empty")

        if all(hasattr(item, "save") for item in images_or_text):
            payload = self._image_payload(images_or_text)
            modality = "image"
        elif all(isinstance(item, str) for item in images_or_text):
            payload = self._text_payload(images_or_text)
            modality = "text"
        else:
            raise ValueError("Images and text cannot be mixed in a single request")

        response = self.session.post(
            f"{self.url}/embeddings",
            json={
                "model": model,
                "input": payload,
                "encoding_format": self.format,
                "modality": modality,
            },
        )
        response.raise_for_status()
        body = response.json()

        if self.format == "base64":
            # Payload is raw float32 bytes; rows are ``hidden_dim`` wide.
            decoded = [
                np.frombuffer(
                    base64.b64decode(entry["embedding"]), dtype=np.float32
                ).reshape(-1, self.hidden_dim)
                for entry in body["data"]
            ]
        else:
            decoded = [np.array(entry["embedding"]) for entry in body["data"]]
        return decoded, body["usage"]["total_tokens"]

    def _submit(self, model: str, items: list) -> Future:
        """Queue ``_request`` on the pool while holding one semaphore slot.

        The slot is taken before submitting and released by a done-callback,
        so it stays occupied for the whole lifetime of the request (queued,
        running, or cancelled).  This call blocks while all slots are taken.
        """
        self.sem.acquire()
        try:
            future = self.tp.submit(self._request, model=model, images_or_text=items)
        except BaseException:
            # Submission failed (e.g. pool already shut down): no callback
            # will ever fire, so give the slot back here.
            self.sem.release()
            raise
        future.add_done_callback(lambda _done: self.sem.release())
        return future

    def embed(self, model: str, sentences: list[str]) -> Future[tuple]:
        """Check service health, then embed ``sentences`` asynchronously.

        Returns:
            A future resolving to ``(embeddings, total_tokens)``.
        """
        self.health()
        return self._submit(model, sentences)

    def image_embed(self, model: str, images: list["Image.Image"]) -> Future[tuple]:
        """Check service health, then embed ``images`` asynchronously.

        Returns:
            A future resolving to ``(embeddings, total_tokens)``.
        """
        self.health()  # Call once instead of per image
        return self._submit(model, images)
111+
def test_colpali():
    """Smoke test: embed a single text with a default client and print it.

    Builds ``InfinityVisionAPI`` with its default arguments, so it performs
    real HTTP requests.
    """
    client = InfinityVisionAPI()
    pending = client.embed("michaelfeil/colqwen2-v0.1", ["test"])
    # The future resolves to an (embeddings, total_tokens) pair.
    vectors, token_count = pending.result()
    print(vectors, token_count)
117+
# Run the smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_colpali()
0 commit comments