Skip to content

Commit 065ae9c

Browse files
committed
feat: add example plugin for API-backed picture description with token usage
Signed-off-by: FrigaZzz <[email protected]>
1 parent 0610d01 commit 065ae9c

File tree

8 files changed

+983
-2
lines changed

8 files changed

+983
-2
lines changed

docs/concepts/plugins.md

Lines changed: 516 additions & 0 deletions
Large diffs are not rendered by default.

docs/examples/rag_mongodb.ipynb

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,6 @@
452452
"source": [
453453
"## Part 4: Perform RAG on parsed articles\n",
454454
"\n",
455-
"Weaviate's `generate` module allows you to perform RAG over your embedded data without having to use a separate framework.\n",
456-
"\n",
457455
"We specify a prompt that includes the field we want to search through in the database (in this case it's `text`), a query that includes our search term, and the number of retrieved results to use in the generation."
458456
]
459457
},
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from api_usage.models.picture_description_api_model import (
2+
PictureDescriptionApiModelWithUsage,
3+
)
4+
5+
6+
def picture_description():
    """Docling plugin entry point for picture-description models.

    Returns a registry dict that maps the "picture_description" extension
    point to the model classes contributed by this plugin.
    """
    registry = {"picture_description": [PictureDescriptionApiModelWithUsage]}
    return registry
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
2+
3+
from pydantic import (
4+
AnyUrl,
5+
BaseModel,
6+
ConfigDict,
7+
Field,
8+
)
9+
10+
from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions
11+
12+
13+
class PictureDescriptionApiOptionsWithUsage(PictureDescriptionBaseOptions):
    """Options for the API-backed picture-description model with token usage.

    Configures the OpenAI-compatible chat-completions endpoint, per-request
    parameters, and an optional key used to extract a value from the
    response's ``usage`` dict.
    """

    # Discriminator used by docling to match these options to the model class.
    kind: ClassVar[Literal["api_usage"]] = "api_usage"

    # OpenAI-compatible chat-completions endpoint.
    url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions")
    # Extra HTTP headers (e.g. Authorization). Pydantic copies mutable
    # class-level defaults per instance, so the {} default is safe here.
    headers: Dict[str, str] = {}
    # Additional JSON payload fields merged into each request (e.g. model name).
    params: Dict[str, Any] = {}
    # Request timeout in seconds.
    timeout: float = 20
    # Maximum number of concurrent API requests.
    concurrency: int = 1

    # Prompt sent alongside each image.
    prompt: str = "Describe this image in a few sentences."
    provenance: str = ""
    # Key inside the response 'usage' (or similar) which will be used to extract
    # the token/response text. Example: 'content' or 'text'. If None, no
    # token extraction will be performed by default.
    token_extract_key: Optional[str] = Field(
        None,
        description=(
            "Key in the response usage dict whose value contains the token/"
            "response to extract. For example 'content' or 'text'."
        ),
    )
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import base64
2+
import json
3+
import logging
4+
from io import BytesIO
5+
from typing import Dict, List, Optional, Tuple
6+
7+
import requests
8+
from PIL import Image
9+
from pydantic import AnyUrl
10+
11+
from docling.datamodel.base_models import OpenAiApiResponse
12+
from docling.models.utils.generation_utils import GenerationStopper
13+
14+
_log = logging.getLogger(__name__)
15+
16+
17+
def _extract_generated_text(resp_json, raw_body: str) -> str:
    """Best-effort extraction of generated text from common OpenAI response shapes.

    Tries, in order: the chat shape (choices[0].message.content), the legacy
    completion shape (choices[0].text), and finally pydantic validation of the
    raw body. Returns "" when nothing matches.
    """
    try:
        return resp_json["choices"][0]["message"]["content"].strip()
    except (KeyError, IndexError, TypeError, AttributeError):
        pass
    try:
        text = resp_json["choices"][0].get("text", "")
        return text.strip() if isinstance(text, str) else text
    except (KeyError, IndexError, TypeError, AttributeError):
        pass
    try:
        api_resp = OpenAiApiResponse.model_validate_json(raw_body)
        return api_resp.choices[0].message.content.strip()
    except Exception:
        # Last-resort fallback: the body matched no known shape.
        return ""


def api_image_request(
    image: Image.Image,
    prompt: str,
    url: AnyUrl,
    timeout: float = 20,
    headers: Optional[Dict[str, str]] = None,
    token_extract_key: Optional[str] = None,
    **params,
) -> Tuple[str, Optional[dict]]:
    """Send an image+prompt to an OpenAI-compatible API and return (text, usage).

    Args:
        image: Image to describe; sent as a base64-encoded PNG data URL.
        prompt: Text prompt sent alongside the image.
        url: OpenAI-compatible chat-completions endpoint.
        timeout: Request timeout in seconds.
        headers: Optional extra HTTP headers (e.g. Authorization).
        token_extract_key: If set and present in the response 'usage' dict,
            its (stringified) value replaces the generated text.
        **params: Extra payload fields merged into the request JSON.

    Returns:
        Tuple of (generated text, usage dict). The usage element is None
        when the response carries no usage data.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
    """
    img_io = BytesIO()
    image.save(img_io, "PNG")
    image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]

    payload = {"messages": messages, **params}

    r = requests.post(
        str(url), headers=headers or {}, json=payload, timeout=timeout
    )
    if not r.ok:
        _log.error(f"Error calling the API. Response was {r.text}")
    r.raise_for_status()

    try:
        resp_json = r.json()
    except ValueError:
        # Body is not JSON at all; let pydantic parse it (raises if invalid).
        # No usage information is available in this case.
        api_resp = OpenAiApiResponse.model_validate_json(r.text)
        return api_resp.choices[0].message.content.strip(), None

    usage = resp_json.get("usage") if isinstance(resp_json, dict) else None
    generated_text = _extract_generated_text(resp_json, r.text)

    # If an explicit token_extract_key is provided and found in usage, its
    # value overrides the extracted text.
    if token_extract_key and isinstance(usage, dict) and token_extract_key in usage:
        extracted = usage.get(token_extract_key)
        if extracted is not None:
            generated_text = str(extracted).strip()

    return generated_text, usage
90+
91+
92+
def api_image_request_streaming(
    image: Image.Image,
    prompt: str,
    url: AnyUrl,
    *,
    timeout: float = 20,
    headers: Optional[Dict[str, str]] = None,
    generation_stoppers: Optional[List[GenerationStopper]] = None,
    **params,
) -> str:
    """
    Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
    Parses SSE lines: 'data: {json}\n\n', terminated by 'data: [DONE]'.
    Accumulates text and calls stopper.should_stop(window) as chunks arrive.
    If stopper triggers, the HTTP connection is closed to abort server-side generation.

    Args:
        image: Image to describe; sent as a base64-encoded PNG data URL.
        prompt: Text prompt sent alongside the image.
        url: OpenAI-compatible chat-completions endpoint.
        timeout: Request timeout in seconds.
        headers: Optional extra HTTP headers (e.g. Authorization).
        generation_stoppers: Stoppers consulted after each chunk; defaults to
            none. (Default is None, not [], to avoid a shared mutable default.)
        **params: Extra payload fields merged into the request JSON.

    Returns:
        The accumulated generated text (possibly truncated by a stopper).

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
    """
    stoppers = generation_stoppers or []

    img_io = BytesIO()
    image.save(img_io, "PNG")
    image_b64 = base64.b64encode(img_io.getvalue()).decode("utf-8")

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]

    payload = {"messages": messages, "stream": True, **params}
    _log.debug(f"API streaming request payload: {json.dumps(payload, indent=2)}")

    hdrs = {"Accept": "text/event-stream", **(headers or {})}
    # NOTE(review): mirroring temperature into a custom header — presumably
    # for a proxy/server that reads it; verify this is still needed.
    if "temperature" in params:
        hdrs["X-Temperature"] = str(params["temperature"])

    # Stream the HTTP response; the context manager closes the connection,
    # which also aborts server-side generation on early return.
    with requests.post(
        str(url), headers=hdrs, json=payload, timeout=timeout, stream=True
    ) as r:
        if not r.ok:
            _log.error(
                f"Error calling the API {url} in streaming mode. Response was {r.text}"
            )
        r.raise_for_status()

        full_text: List[str] = []
        for raw_line in r.iter_lines(decode_unicode=True):
            if not raw_line:  # keep-alives / blank lines
                continue
            if not raw_line.startswith("data:"):
                # Some proxies inject comments; ignore anything not starting with 'data:'
                continue

            data = raw_line[len("data:") :].strip()
            if data == "[DONE]":
                break

            try:
                obj = json.loads(data)
            except json.JSONDecodeError:
                _log.debug("Skipping non-JSON SSE chunk: %r", data[:200])
                continue

            try:
                delta = obj["choices"][0].get("delta") or {}
                piece = delta.get("content") or ""
            except (KeyError, IndexError) as e:
                _log.debug("Unexpected SSE chunk shape: %s", e)
                piece = ""

            if piece:
                full_text.append(piece)
                for stopper in stoppers:
                    # Only the stopper's lookback window is scanned, not the
                    # whole accumulated text.
                    lookback = max(1, stopper.lookback_tokens())
                    window = "".join(full_text)[-lookback:]
                    if stopper.should_stop(window):
                        return "".join(full_text)

    return "".join(full_text)
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from collections.abc import Iterable
2+
from concurrent.futures import ThreadPoolExecutor
3+
from pathlib import Path
4+
from typing import List, Literal, Optional, Type, Union
5+
6+
from api_usage.datamodel.pipeline_options.picture_description_api_model_with_usage import (
7+
PictureDescriptionApiOptionsWithUsage,
8+
)
9+
from api_usage.datamodel.utils.api_image_request_with_usage import api_image_request
10+
from docling_core.types.doc import DoclingDocument, NodeItem, PictureItem
11+
from docling_core.types.doc.document import (
12+
BaseAnnotation,
13+
) # TODO: move import to docling_core.types.doc
14+
from PIL import Image
15+
16+
from docling.datamodel.accelerator_options import AcceleratorOptions
17+
from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions
18+
from docling.exceptions import OperationNotAllowed
19+
from docling.models.base_model import ItemAndImageEnrichmentElement
20+
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
21+
22+
23+
class DescriptionAnnotationWithUsage(BaseAnnotation):
    """Picture description annotation that also carries API token usage.

    Same shape as docling's description annotation, extended with an
    optional ``token_usage`` dict taken from the API response.
    """

    kind: Literal["description"] = "description"
    text: str
    provenance: str
    # Raw 'usage' dict from the API response (e.g. token counts); None when
    # the response carried no usage data.
    token_usage: Optional[dict] = None
30+
31+
32+
class PictureDescriptionApiModelWithUsage(PictureDescriptionBaseModel):
    """Picture-description model calling an OpenAI-compatible API, keeping token usage.

    Annotates pictures via `api_image_request` and attaches a
    `DescriptionAnnotationWithUsage` (text + provenance + usage dict) to each
    described `PictureItem`.
    """

    @classmethod
    def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
        """Return the options class this model is configured with."""
        return PictureDescriptionApiOptionsWithUsage

    def __init__(
        self,
        enabled: bool,
        enable_remote_services: bool,
        artifacts_path: Optional[Union[Path, str]],
        options: PictureDescriptionApiOptionsWithUsage,
        accelerator_options: AcceleratorOptions,
    ):
        """Initialize the model.

        Raises:
            OperationNotAllowed: if `enabled` is True but remote services
                were not explicitly allowed via `enable_remote_services`.
        """
        super().__init__(
            enabled=enabled,
            enable_remote_services=enable_remote_services,
            artifacts_path=artifacts_path,
            options=options,
            accelerator_options=accelerator_options,
        )
        # Narrow the attribute type for type checkers; value set by super().
        self.options: PictureDescriptionApiOptionsWithUsage
        self.concurrency = self.options.concurrency

        if self.enabled:
            if not enable_remote_services:
                raise OperationNotAllowed(
                    "Connections to remote services is only allowed when set explicitly. "
                    "pipeline_options.enable_remote_services=True."
                )

    def _annotate_images(
        self, images: Iterable[Image.Image]
    ) -> Iterable[tuple[str, Optional[dict]]]:
        """Yield (text, usage) for each image, querying the API concurrently.

        Note: technically we could make a batch request here, but not all
        APIs will allow for it. For example, vllm won't allow more than 1.
        """

        def _api_request(image):
            # Pass token_extract_key so api_image_request can return token usage.
            return api_image_request(
                image=image,
                prompt=self.options.prompt,
                url=self.options.url,
                timeout=self.options.timeout,
                headers=self.options.headers,
                token_extract_key=getattr(self.options, "token_extract_key", None),
                **self.options.params,
            )

        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
            yield from executor.map(_api_request, images)

    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[ItemAndImageEnrichmentElement],
    ) -> Iterable[NodeItem]:
        """Enrich each picture element with a description annotation.

        Pictures whose page-area fraction is below
        `options.picture_area_threshold` are passed through undescribed.
        """
        if not self.enabled:
            for element in element_batch:
                yield element.item
            return

        images: List[Image.Image] = []
        elements: List[PictureItem] = []
        for el in element_batch:
            assert isinstance(el.item, PictureItem)
            describe_image = True
            # Don't describe the image if it's smaller than the threshold
            if len(el.item.prov) > 0:
                prov = el.item.prov[0]  # PictureItems have at most a single provenance
                page = doc.pages.get(prov.page_no)
                if page is not None:
                    page_area = page.size.width * page.size.height
                    if page_area > 0:
                        area_fraction = prov.bbox.area() / page_area
                        if area_fraction < self.options.picture_area_threshold:
                            describe_image = False
            if describe_image:
                elements.append(el.item)
                images.append(el.image)

        outputs = self._annotate_images(images)

        for item, output in zip(elements, outputs):
            # api_image_request now may return (text, usage) or plain text;
            # normalize to tuple
            if isinstance(output, tuple):
                text, usage = output
            else:
                text, usage = output, None

            item.annotations.append(
                DescriptionAnnotationWithUsage(
                    text=text, provenance=self.provenance, token_usage=usage
                )
            )
            yield item

0 commit comments

Comments
 (0)