Skip to content

Commit 6dd20b8

Browse files
authored
feat: picture description via langchain llm (#22)
* feat: picture description via langchain llm Signed-off-by: Michele Dolfi <[email protected]> * add docstrings Signed-off-by: Michele Dolfi <[email protected]> * cleanup Signed-off-by: Michele Dolfi <[email protected]> --------- Signed-off-by: Michele Dolfi <[email protected]>
1 parent cb9aa96 commit 6dd20b8

File tree

5 files changed

+450
-9
lines changed

5 files changed

+450
-9
lines changed

examples/docling_picture_description.ipynb

Lines changed: 167 additions & 0 deletions
Large diffs are not rendered by default.

langchain_docling/_plugins.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Register Docling plugins."""
2+
3+
4+
def picture_description():
5+
"""Picture description plugins."""
6+
from langchain_docling.picture_description import PictureDescriptionLangChainModel
7+
8+
return {
9+
"picture_description": [
10+
PictureDescriptionLangChainModel,
11+
]
12+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Picture description model using LangChain primitives."""
2+
3+
import base64
4+
import io
5+
from collections.abc import Iterable
6+
from pathlib import Path
7+
from typing import ClassVar, Literal, Optional, Type, Union
8+
9+
from docling.datamodel.accelerator_options import AcceleratorOptions
10+
from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions
11+
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
12+
from docling.models.utils.hf_model_download import HuggingFaceModelDownloadMixin
13+
from langchain_core.language_models.chat_models import BaseChatModel
14+
from PIL import Image
15+
16+
17+
class PictureDescriptionLangChainOptions(PictureDescriptionBaseOptions):
18+
"""Options for the PictureDescriptionLangChainModel."""
19+
20+
kind: ClassVar[Literal["langchain"]] = "langchain"
21+
llm: BaseChatModel
22+
prompt: str = "Describe this document picture in a few sentences."
23+
provenance: Optional[str] = None
24+
25+
26+
class PictureDescriptionLangChainModel(
27+
PictureDescriptionBaseModel, HuggingFaceModelDownloadMixin
28+
):
29+
"""Implementation of a PictureDescription model using LangChain."""
30+
31+
@classmethod
32+
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
33+
"""Define the option type for the factory."""
34+
return PictureDescriptionLangChainOptions
35+
36+
def __init__(
37+
self,
38+
enabled: bool,
39+
enable_remote_services: bool,
40+
artifacts_path: Optional[Union[Path, str]],
41+
options: PictureDescriptionLangChainOptions,
42+
accelerator_options: AcceleratorOptions,
43+
):
44+
"""Initialize PictureDescriptionLangChainModel."""
45+
super().__init__(
46+
enabled=enabled,
47+
enable_remote_services=enable_remote_services,
48+
artifacts_path=artifacts_path,
49+
options=options,
50+
accelerator_options=accelerator_options,
51+
)
52+
self.options: PictureDescriptionLangChainOptions
53+
54+
if self.enabled:
55+
self.llm = self.options.llm
56+
self.provenance = "langchain"
57+
if self.options.provenance:
58+
self.provenance += f"-{self.options.provenance}"
59+
60+
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
61+
"""Annotate the images with the LangChain model."""
62+
# Create input messages
63+
batch_messages = []
64+
65+
for image in images:
66+
buffered = io.BytesIO()
67+
image.save(buffered, format="PNG")
68+
image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
69+
batch_messages.append(
70+
[
71+
{
72+
"role": "user",
73+
"content": [
74+
{"type": "text", "text": self.options.prompt},
75+
{
76+
"type": "image_url",
77+
"image_url": {
78+
"url": f"data:image/png;base64,{image_data}"
79+
},
80+
},
81+
],
82+
}
83+
]
84+
)
85+
86+
responses = self.llm.batch(batch_messages)
87+
for resp in responses:
88+
yield resp.text()

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ dev = [
6363
"pytest~=8.3",
6464
"pytest-cov>=6.1.1",
6565
"python-semantic-release~=7.32",
66+
"langchain-openai>=0.2.12",
6667
]
6768

6869
[tool.uv]
@@ -72,6 +73,9 @@ default-groups = "all"
7273
[tool.setuptools.packages.find]
7374
include = ["langchain_docling*"]
7475

76+
[project.entry-points."docling"]
77+
langchain_docling = "langchain_docling._plugins"
78+
7579
[tool.black]
7680
line-length = 88
7781
target-version = ["py39", "py310"]

0 commit comments

Comments
 (0)