import cocoindex
import io
import torch
import functools

from dataclasses import dataclass
from PIL import Image
from pypdf import PdfReader
from transformers import CLIPModel, CLIPProcessor
from typing import Literal


QDRANT_GRPC_URL = "http://localhost:6334"
QDRANT_COLLECTION_IMAGE = "PdfElementsEmbeddingImage"
QDRANT_COLLECTION_TEXT = "PdfElementsEmbeddingText"

CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
CLIP_MODEL_DIMENSION = 768
ClipVectorType = cocoindex.Vector[cocoindex.Float32, Literal[CLIP_MODEL_DIMENSION]]

IMG_THUMBNAIL_SIZE = (512, 512)


@functools.cache
def get_clip_model() -> tuple[CLIPModel, CLIPProcessor]:
    """Load the CLIP model and processor once; cached for reuse across calls."""
    model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
    processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
    return model, processor


@cocoindex.op.function(cache=True, behavior_version=1, gpu=True)
def clip_embed_image(img_bytes: bytes) -> ClipVectorType:
    """
    Convert an image to an embedding using the CLIP model.
    """
    model, processor = get_clip_model()
    image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    return features[0].tolist()


def clip_embed_query(text: str) -> ClipVectorType:
    """
    Embed a text query with the CLIP model, so it can be matched against
    image embeddings in the same vector space.
    """
    model, processor = get_clip_model()
    inputs = processor(text=[text], return_tensors="pt", padding=True)
    with torch.no_grad():
        features = model.get_text_features(**inputs)
    return features[0].tolist()


@cocoindex.transform_flow()
def embed_text(
    text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[cocoindex.Vector[cocoindex.Float32]]:
    """
    Embed text using a SentenceTransformer model.
    This logic is shared between indexing and querying, so it is extracted
    into a standalone function.
    """
    return text.transform(
        cocoindex.functions.SentenceTransformerEmbed(
            model="sentence-transformers/all-MiniLM-L6-v2"
        )
    )


@dataclass
class PdfImage:
    name: str
    data: bytes


@dataclass
class PdfPage:
    page_number: int
    text: str
    images: list[PdfImage]


@cocoindex.op.function()
def extract_pdf_elements(content: bytes) -> list[PdfPage]:
    """
    Extract text and images from a PDF file.
    """
    reader = PdfReader(io.BytesIO(content))
    result = []
    for i, page in enumerate(reader.pages):
        text = page.extract_text()
        images = []
        for image in page.images:
            img = image.image
            if img is None:
                continue
            # Skip very small images, which are unlikely to carry content.
            if img.width < 16 or img.height < 16:
                continue
            # Downscale in place (thumbnail() preserves aspect ratio), then
            # re-encode in the original format, falling back to PNG.
            thumbnail = io.BytesIO()
            img.thumbnail(IMG_THUMBNAIL_SIZE)
            img.save(thumbnail, img.format or "PNG")
            images.append(PdfImage(name=image.name, data=thumbnail.getvalue()))
        result.append(PdfPage(page_number=i + 1, text=text, images=images))
    return result


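# Register the Qdrant connection once as a named auth entry; both exports
# below reference it.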
qdrant_connection = cocoindex.add_auth_entry(
    "qdrant_connection",
    cocoindex.targets.QdrantConnection(grpc_url=QDRANT_GRPC_URL),
)


@cocoindex.flow_def(name="PdfElementsEmbedding")
def multi_format_indexing_flow(
    flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
) -> None:
    """
    Define an example flow that embeds PDF text chunks and images into a
    vector database.
    """
    data_scope["documents"] = flow_builder.add_source(
        cocoindex.sources.LocalFile(
            path="source_files", included_patterns=["*.pdf"], binary=True
        )
    )

    # Text chunks and images are collected separately and exported to two
    # different Qdrant collections.
    text_output = data_scope.add_collector()
    image_output = data_scope.add_collector()
    with data_scope["documents"].row() as doc:
        doc["pages"] = doc["content"].transform(extract_pdf_elements)
        with doc["pages"].row() as page:
            # Split page text into overlapping chunks before embedding.
            page["chunks"] = page["text"].transform(
                cocoindex.functions.SplitRecursively(
                    custom_languages=[
                        cocoindex.functions.CustomLanguageSpec(
                            language_name="text",
                            separators_regex=[
                                r"\n(\s*\n)+",
                                r"[\.!\?]\s+",
                                r"\n",
                                r"\s+",
                            ],
                        )
                    ]
                ),
                language="text",
                chunk_size=600,
                chunk_overlap=100,
            )
            with page["chunks"].row() as chunk:
                chunk["embedding"] = chunk["text"].call(embed_text)
                text_output.collect(
                    id=cocoindex.GeneratedField.UUID,
                    filename=doc["filename"],
                    page=page["page_number"],
                    text=chunk["text"],
                    embedding=chunk["embedding"],
                )
            with page["images"].row() as image:
                image["embedding"] = image["data"].transform(clip_embed_image)
                image_output.collect(
                    id=cocoindex.GeneratedField.UUID,
                    filename=doc["filename"],
                    page=page["page_number"],
                    image_data=image["data"],
                    embedding=image["embedding"],
                )

    text_output.export(
        "text_embeddings",
        cocoindex.targets.Qdrant(
            connection=qdrant_connection,
            collection_name=QDRANT_COLLECTION_TEXT,
        ),
        primary_key_fields=["id"],
    )
    image_output.export(
        "image_embeddings",
        cocoindex.targets.Qdrant(
            connection=qdrant_connection,
            collection_name=QDRANT_COLLECTION_IMAGE,
        ),
        primary_key_fields=["id"],
    )
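

# --- Query-side sketch (illustrative; not part of the indexing flow) ---
# A minimal example of how the two collections could be searched, assuming
# the `qdrant-client` package is installed and that the exported vector
# field is stored as a Qdrant named vector keyed by its field name
# ("embedding"). The helper names and result handling here are hypothetical,
# not prescribed by cocoindex.
def search_text_chunks(query: str, limit: int = 5):
    from qdrant_client import QdrantClient  # assumed query-side dependency

    client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True)
    # Reuse the shared transform flow so queries are embedded the same way
    # as the indexed chunks.
    query_vector = embed_text.eval(query)
    return client.search(
        collection_name=QDRANT_COLLECTION_TEXT,
        query_vector=("embedding", query_vector),
        limit=limit,
    )


def search_images(query: str, limit: int = 5):
    from qdrant_client import QdrantClient  # assumed query-side dependency

    client = QdrantClient(host="localhost", grpc_port=6334, prefer_grpc=True)
    # CLIP text embeddings share a vector space with the image embeddings,
    # so a text query can retrieve images directly.
    query_vector = clip_embed_query(query)
    return client.search(
        collection_name=QDRANT_COLLECTION_IMAGE,
        query_vector=("embedding", query_vector),
        limit=limit,
    )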