Skip to content

Commit f50e3da

Browse files
committed
add vision client
1 parent 9735ed5 commit f50e3da

File tree

3 files changed

+121
-0
lines changed

3 files changed

+121
-0
lines changed

.tool-versions

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
poetry 1.8.2
2+
python 3.11.9
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from io import BytesIO
2+
import base64
3+
import threading
4+
from concurrent.futures import ThreadPoolExecutor, Future
5+
import numpy as np # Import numpy
6+
from typing import Union
7+
HAS_IMPORTS = True
8+
try:
9+
from PIL import Image
10+
import numpy as np
11+
12+
except ImportError:
13+
HAS_IMPORTS = False
14+
15+
try:
16+
import requests
17+
from requests.adapters import HTTPAdapter, Retry
18+
except ImportError:
19+
HAS_IMPORTS = False
20+
21+
class InfinityVisionAPI:
    """Thread-pooled HTTP client for the ``/embeddings`` endpoint of a server.

    Requests are executed on a ``ThreadPoolExecutor`` and returned as
    ``Future`` objects. A semaphore bounds how many submitted requests may be
    outstanding (queued or running) at once; callers block in ``embed`` /
    ``image_embed`` once that bound is reached.

    The instance can be used as a context manager, or ``close()`` can be
    called explicitly, to shut down the worker pool and the HTTP session.
    """

    # Pillow image modes the JPEG encoder can write directly. Anything else
    # (e.g. "RGBA", "P", "LA") is converted to RGB before encoding.
    _JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(self, url: str = "https://infinity-multimodal.modal.michaelfeil.eu", format="base64") -> None:
        """Probe the server once to learn the embedding width, then set up I/O.

        Args:
            url: Base URL of the server; ``/embeddings`` and ``/health`` are
                appended to it.
            format: Value sent as ``encoding_format`` on every request.
                ``"base64"`` responses are decoded as float32 buffers; any
                other value is treated as a JSON list of numbers.

        Raises:
            requests.HTTPError: If the probe request returns an error status.
        """
        # One text request to discover the size of the last embedding axis.
        # NOTE(review): the probe hard-codes the model name, so ``hidden_dim``
        # is only valid for models with the same width -- confirm with callers.
        probe = requests.post(
            url + "/embeddings",
            json={  # get shape of output
                "model": "michaelfeil/colqwen2-v0.1",
                "input": ["test"],
                "encoding_format": "float",
                "modality": "text",
            },
        )
        probe.raise_for_status()
        self.url = url
        self.hidden_dim = np.array(probe.json()["data"][0]["embedding"]).shape[-1]
        self.format = format

        # Worker pool that runs the blocking HTTP calls; released in close().
        self.tp = ThreadPoolExecutor()
        # Upper bound on outstanding (queued + running) requests.
        self.sem = threading.Semaphore(64)

        self.session = requests.Session()
        adapter = HTTPAdapter(max_retries=Retry(total=10, backoff_factor=0.5))
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)

    def __enter__(self) -> "InfinityVisionAPI":
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Release the worker pool and HTTP session on leaving ``with``."""
        self.close()

    def close(self) -> None:
        """Wait for pending requests, then shut down the pool and session."""
        self.tp.shutdown(wait=True)
        self.session.close()

    def _image_payload(self, images: list["Image.Image"]) -> list[str]:
        """Encode images as ``data:image/jpeg;base64,...`` URI strings.

        Raises:
            ImportError: If the optional imports of this module failed.
            ValueError: If an item has no ``save`` method.
        """
        if not HAS_IMPORTS:
            raise ImportError("PIL is required to use this class")
        b64_strs = []
        for image in images:
            if not hasattr(image, "save"):
                raise ValueError("Image must be a PIL Image")
            # JPEG has no alpha/palette support; saving such modes raises, so
            # convert them to RGB first. Supported modes are left untouched.
            if getattr(image, "mode", "RGB") not in self._JPEG_MODES:
                image = image.convert("RGB")
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            b64_strs.append(f"data:image/jpeg;base64,{img_str}")
        return b64_strs

    def _text_payload(self, texts: list[str]) -> list[str]:
        """Return the texts unchanged; they are sent to the server as-is."""
        return texts

    def health(self) -> bool:
        """GET ``/health``; raise on an error status, else report HTTP 200."""
        req = self.session.get(f"{self.url}/health")
        req.raise_for_status()
        return req.status_code == 200

    def _decode_embeddings(self, data: list[dict]) -> list[np.ndarray]:
        """Turn the ``data`` entries of a response into numpy arrays.

        With ``format == "base64"`` each ``embedding`` is decoded as a float32
        buffer and reshaped to ``(-1, hidden_dim)``; otherwise it is passed to
        ``np.array`` unchanged.
        """
        if self.format == "base64":
            return [
                np.frombuffer(
                    base64.b64decode(entry["embedding"]), dtype=np.float32
                ).reshape(-1, self.hidden_dim)
                for entry in data
            ]
        return [np.array(entry["embedding"]) for entry in data]

    def _request(
        self, model: str, images_or_text: list[Union["Image.Image", str]]
    ) -> tuple[list[np.ndarray], int]:
        """POST one batch to ``/embeddings`` and decode the response.

        Args:
            model: Model name forwarded to the server.
            images_or_text: Either all image-like objects (anything with a
                ``save`` method) or all strings.

        Returns:
            ``(embeddings, total_tokens)`` where ``embeddings`` has one array
            per response entry and ``total_tokens`` is read from
            ``usage.total_tokens``.

        Raises:
            ValueError: If images and strings are mixed in one batch.
            requests.HTTPError: If the server returns an error status.
        """
        if all(hasattr(item, "save") for item in images_or_text):
            payload = self._image_payload(images_or_text)
            modality = "image"
        elif all(isinstance(item, str) for item in images_or_text):
            payload = self._text_payload(images_or_text)
            modality = "text"
        else:
            raise ValueError("Images and text cannot be mixed in a single request")

        embeddings_req = self.session.post(
            f"{self.url}/embeddings",
            json={
                "model": model,
                "input": payload,
                "encoding_format": self.format,
                "modality": modality,
            },
        )
        embeddings_req.raise_for_status()
        embeddings = embeddings_req.json()
        return self._decode_embeddings(embeddings["data"]), embeddings["usage"]["total_tokens"]

    def _submit(self, model: str, items: list) -> Future:
        """Health-check the server, then queue ``_request`` on the pool.

        A semaphore permit is taken before queueing and only given back when
        the future finishes (result, exception or cancellation). Releasing it
        right after ``submit`` returned would not limit anything, because
        ``submit`` only enqueues the work.
        """
        self.health()
        self.sem.acquire()
        try:
            future = self.tp.submit(self._request, model=model, images_or_text=items)
        except BaseException:
            # Nothing was queued, so no callback will ever return the permit.
            self.sem.release()
            raise
        future.add_done_callback(lambda _finished: self.sem.release())
        return future

    def embed(self, model: str, sentences: list[str]) -> Future[tuple[list[np.ndarray], int]]:
        """Queue a text embedding request.

        Returns:
            A future resolving to ``(embeddings, total_tokens)``.
        """
        return self._submit(model, sentences)

    def image_embed(self, model: str, images: list["Image.Image"]) -> Future[tuple[list[np.ndarray], int]]:
        """Queue an image embedding request.

        Returns:
            A future resolving to ``(embeddings, total_tokens)``.
        """
        return self._submit(model, images)
111+
112+
def test_colpali():
    """Smoke test: embed one text against the default server and print it.

    Builds a client with its default URL, requests an embedding for the
    single string "test", waits for the future and prints the decoded
    embeddings together with the reported token count.
    """
    client = InfinityVisionAPI()
    pending = client.embed("michaelfeil/colqwen2-v0.1", ["test"])
    result = pending.result()
    embeddings, total_tokens = result
    print(embeddings, total_tokens)
117+
118+
# Run the smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_colpali()

libs/client_infinity/run_generate_with_hook.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,6 @@ pip install openapi-python-client==0.21.1
3838
--custom-template-path=./template
3939

4040
# copy the vision client into the generated package, then the readme to docs
41+
cp ./template/vision_client.py ./infinity_client/infinity_client/vision_client.py
4142
cp ./infinity_client/README.md ./../../docs/docs/client_infinity.md
4243
# Cleanup will be called due to the trap

0 commit comments

Comments
 (0)