diff --git a/core/services/document_service.py b/core/services/document_service.py
index bbfe066f..70409845 100644
--- a/core/services/document_service.py
+++ b/core/services/document_service.py
@@ -366,7 +366,7 @@ async def batch_retrieve_chunks(
         auth: AuthContext,
         folder_name: Optional[str] = None,
         end_user_id: Optional[str] = None,
-        use_colpali: Optional[bool] = None,
+        retrieve_images: Optional[bool] = None,
     ) -> List[ChunkResult]:
         """
         Retrieve specific chunks by their document ID and chunk number in a single batch operation.
@@ -376,7 +376,7 @@ async def batch_retrieve_chunks(
             auth: Authentication context
             folder_name: Optional folder to scope the operation to
             end_user_id: Optional end-user ID to scope the operation to
-            use_colpali: Whether to use colpali multimodal features for image chunks
+            retrieve_images: Whether to also retrieve image chunks from the colpali vector store

         Returns:
             List of ChunkResult objects
@@ -404,7 +404,7 @@ async def batch_retrieve_chunks(
         retrieval_tasks = [self.vector_store.get_chunks_by_id(chunk_identifiers)]

         # Add colpali vector store task if needed
-        if use_colpali and self.colpali_vector_store:
+        if retrieve_images and self.colpali_vector_store:
             logger.info("Preparing to retrieve chunks from both regular and colpali vector stores")
             retrieval_tasks.append(self.colpali_vector_store.get_chunks_by_id(chunk_identifiers))

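Note for reviewers: the rename from `use_colpali` to `retrieve_images` is mechanical, but call sites must switch to the new keyword. A minimal sketch of an updated caller, assuming a `document_service` instance, a `chunk_sources` list, and an `auth` context already in scope (those names are illustrative; only `batch_retrieve_chunks`, `retrieve_images`, and the `is_image` metadata key come from this patch):

```python
# Hypothetical call site after the rename (previously: use_colpali=True).
chunks = await document_service.batch_retrieve_chunks(
    chunk_sources,
    auth,
    retrieve_images=True,  # also fetch image chunks from the colpali store, if configured
)
image_chunks = [c for c in chunks if c.metadata.get("is_image", False)]
```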
diff --git a/core/services/graph_service.py b/core/services/graph_service.py
index 4ee852f2..167b97bb 100644
--- a/core/services/graph_service.py
+++ b/core/services/graph_service.py
@@ -530,8 +530,12 @@ async def _process_documents_for_entities(
                 for i, _ in enumerate(doc.chunk_ids)
             ]

-            # Batch retrieve chunks
-            chunks = await document_service.batch_retrieve_chunks(chunk_sources, auth)
+            # Batch retrieve chunks, including image chunks when available
+            chunks = await document_service.batch_retrieve_chunks(
+                chunk_sources,
+                auth,
+                retrieve_images=True,
+            )
             logger.info(f"Retrieved {len(chunks)} chunks for processing")

             # Process each chunk individually
@@ -545,7 +549,11 @@

                 # Extract entities and relationships from the chunk
                 chunk_entities, chunk_relationships = await self.extract_entities_from_text(
-                    chunk.content, chunk.document_id, chunk.chunk_number, extraction_overrides
+                    chunk.content,
+                    chunk.document_id,
+                    chunk.chunk_number,
+                    extraction_overrides,
+                    override_is_image=chunk.metadata.get("is_image", False),
                 )

                 # Store all initially extracted entities to track their IDs
@@ -695,6 +703,7 @@ async def extract_entities_from_text(
         doc_id: str,
         chunk_number: int,
         prompt_overrides: Optional[EntityExtractionPromptOverride] = None,
+        override_is_image: Optional[bool] = None,
     ) -> Tuple[List[Entity], List[Relationship]]:
         """
         Extract entities and relationships from text content using the LLM.
@@ -703,17 +712,43 @@ async def extract_entities_from_text(
             content: Text content to process
             doc_id: Document ID
             chunk_number: Chunk number within the document
+            prompt_overrides: Optional EntityExtractionPromptOverride with prompt customizations
+            override_is_image: Optional flag forcing image handling based on chunk metadata

         Returns:
             Tuple of (entities, relationships)
         """
         settings = get_settings()

-        # Limit text length to avoid token limits
-        content_limited = content[: min(len(content), 5000)]
-
-        # We'll use the Pydantic model directly when calling litellm
-        # No need to generate JSON schema separately
+        # Determine whether this chunk is an image based solely on the caller-provided flag
+        is_image = bool(override_is_image)
+        # For images, send the full base64 content; for text, truncate to stay within token limits
+        if is_image:
+            content_limited = content
+        else:
+            content_limited = content[: min(len(content), 5000)]
+
+        # Build the system message, differentiating between text and image inputs
+        if is_image:
+            # For images, instruct comprehensive visual and textual interpretation
+            system_content = (
+                "You are a multi-modal extraction assistant for images. The input is a base64-encoded PNG of a document page. "
+                "Perform OCR to extract the text, and visually interpret all layout elements: tables, diagrams, charts, form fields, headings, and graphical icons. "
+                "Identify entities from both text and visuals, infer relationships depicted (e.g., flows, hierarchies, links), and draw logical conclusions from the combined information. "
+                "For entities, include their label and type (e.g., PERSON, ORGANIZATION, LOCATION, CONCEPT). "
+                "For relationships, output a JSON list of objects with source, target, and relationship fields. "
+                "Respond only with valid JSON representing the extracted entities and relationships."
+            )
+        else:
+            # For text, use the standard extraction instructions
+            system_content = (
+                "You are an entity extraction and relationship extraction assistant. "
+                "Extract entities and their relationships from the input text precisely and thoroughly. "
+                "For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). "
+                "For relationships, use a simple JSON format with source, target, and relationship fields. "
+                "Respond directly in valid JSON format, without any additional text or explanations."
+            )
+        system_message = {"role": "system", "content": system_content}

         # Get entity extraction overrides if available
         extraction_overrides = {}
@@ -744,45 +779,47 @@ async def extract_entities_from_text(
                 f"{json.dumps(examples_json, indent=2)}\n```\n"
             )

-        # Modify the system message to handle properties as a string that will be parsed later
-        system_message = {
-            "role": "system",
-            "content": (
-                "You are an entity extraction and relationship extraction assistant. Extract entities and "
-                "their relationships from text precisely and thoroughly, extract as many entities and "
-                "relationships as possible. "
-                "For entities, include entity label and type (some examples: PERSON, ORGANIZATION, LOCATION, "
-                "CONCEPT, etc.). If the user has given examples, use those, these are just suggestions"
-                "For relationships, use a simple format with source, target, and relationship fields. "
-                "Be very through, there are many relationships that are not obvious"
-                "IMPORTANT: The source and target fields must be simple strings representing "
-                "entity labels. For example: "
-                "if you extract entities 'Entity A' and 'Entity B', a relationship would have source: 'Entity A', "
-                "target: 'Entity B', relationship: 'relates to'. "
-                "Respond directly in json format, without any additional text or explanations. "
-            ),
-        }
+        # Construct the user message content as a list of content blocks
+        user_message_content = []
+
+        if is_image:
+            # For images, add the image itself as a content block
+            user_message_content.append({"type": "image_url", "image_url": {"url": content_limited}})
+            # Then add the image-specific instructions as a text block
+            image_instructions = (
+                "Extract named entities and their relationships from the following image. "
+                "Perform OCR and visually interpret all layout elements. "
+                "Return your response as valid JSON.\n\n"
+            )
+            user_message_content.append({"type": "text", "text": image_instructions})
+        else:
+            # For text, add the instructions followed by the content itself
+            text_instructions = (
+                "Extract named entities and their relationships from the following text. "
+                "For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). "
+                "For relationships, specify the source entity, target entity, and the relationship between them. "
+                'Sample relationship format: {"source": "Entity A", "target": "Entity B", '
+                '"relationship": "works for"}\n\n'
+                "Return your response as valid JSON:\n\n"
+            )
+            user_message_content.append({"type": "text", "text": text_instructions})
+            user_message_content.append({"type": "text", "text": content_limited})

-        # Use custom prompt if provided, otherwise use default
+        # Add examples if provided and no custom prompt template is handling them
+        if examples_str and not custom_prompt:
+            user_message_content.append({"type": "text", "text": examples_str})
+
+        # Use the custom prompt template if provided
         if custom_prompt:
-            user_message = {
-                "role": "user",
-                "content": custom_prompt.format(content=content_limited, examples=examples_str),
-            }
+            # A custom prompt takes precedence: it receives the raw content and the
+            # formatted examples via its template placeholders and is responsible for
+            # incorporating them (including any image payload) into the final user
+            # message text.
+            formatted_user_text = custom_prompt.format(content=content_limited, examples=examples_str)
+            user_message = {"role": "user", "content": formatted_user_text}
         else:
-            user_message = {
-                "role": "user",
-                "content": (
-                    "Extract named entities and their relationships from the following text. "
-                    "For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). "
-                    "For relationships, specify the source entity, target entity, and the relationship between them. "
-                    "The source and target must be simple strings matching the entity labels, not objects. "
-                    f"{examples_str}"
-                    'Sample relationship format: {"source": "Entity A", "target": "Entity B", '
-                    '"relationship": "works for"}\n\n'
-                    "Return your response as valid JSON:\n\n" + content_limited
-                ),
-            }
+            # Otherwise, use the constructed list of content blocks
+            user_message = {"role": "user", "content": user_message_content}

         # Get the model configuration from registered_models
         model_config = settings.REGISTERED_MODELS.get(settings.GRAPH_MODEL, {})
diff --git a/core/tests/unit/test_graph_service_image_extraction.py b/core/tests/unit/test_graph_service_image_extraction.py
new file mode 100644
index 00000000..57cc060c
--- /dev/null
+++ b/core/tests/unit/test_graph_service_image_extraction.py
@@ -0,0 +1,92 @@
+import asyncio
+
+import pytest
+
+from core.services.graph_service import GraphService, ExtractionResult, EntityExtraction, RelationshipExtraction
+
+
+# Dummy settings for testing
+class DummySettings:
+    GRAPH_MODEL = "dummy"
+    REGISTERED_MODELS = {"dummy": {"model_name": "test-model"}}
+
+
+# Dummy instructor client that captures the messages it receives and returns a
+# fixed ExtractionResult, so the tests can assert on prompts without a real LLM call.
+class DummyClient:
+    def __init__(self):
+        self.captured_messages = None
+        self.chat = self
+        self.completions = self
+
+    async def create(self, model, messages, response_model, **kwargs):
+        self.captured_messages = messages
+        return ExtractionResult(
+            entities=[EntityExtraction(label="TestEntity", type="CONCEPT")],
+            relationships=[RelationshipExtraction(source="TestEntity", target="TestEntity", relationship="related_to")],
+        )
+
+
+@pytest.fixture(autouse=True)
+def patch_settings_and_instructor(monkeypatch):
+    # Patch get_settings to return DummySettings
+    import core.services.graph_service as gs_mod
+
+    monkeypatch.setattr(gs_mod, "get_settings", lambda: DummySettings())
+
+    # Insert dummy instructor and litellm modules into sys.modules so the dynamic
+    # imports inside extract_entities_from_text resolve to our stub client.
+    import sys
+    import types
+
+    dummy_client = DummyClient()
+    dummy_instructor = types.SimpleNamespace(
+        from_litellm=lambda ac, mode: dummy_client,
+        Mode=types.SimpleNamespace(JSON=None),
+    )
+    dummy_litellm = types.SimpleNamespace(acompletion="dummy")
+    monkeypatch.setitem(sys.modules, "instructor", dummy_instructor)
+    monkeypatch.setitem(sys.modules, "litellm", dummy_litellm)
+    return dummy_client
+
+
+@pytest.mark.parametrize(
+    "content,override_is_image,expected_system_prefix",
+    [
+        # Image handling is driven by override_is_image, not by sniffing the content
+        ("data:image/png;base64,AAA", True, "You are a multi-modal extraction assistant for images."),
+        ("Plain text content.", False, "You are an entity extraction and relationship extraction assistant."),
+    ],
+)
+def test_system_prompt_for_image_vs_text(patch_settings_and_instructor, content, override_is_image, expected_system_prefix):
+    service = GraphService(db=None, embedding_model=None, completion_model=None)
+    entities, relationships = asyncio.run(
+        service.extract_entities_from_text(content, doc_id="doc1", chunk_number=0, override_is_image=override_is_image)
+    )
+    dummy = patch_settings_and_instructor
+    assert dummy.captured_messages is not None
+    system_msg, _ = dummy.captured_messages
+    assert system_msg["content"].startswith(expected_system_prefix)
+    assert entities and entities[0].label == "TestEntity"
+    assert relationships and relationships[0].type == "related_to"
+
+
+@pytest.mark.parametrize(
+    "content,override_is_image,expected_user_prefix,text_block_index",
+    [
+        # For images the first content block is the image; the instruction text follows it
+        ("data:image/png;base64,BBB", True, "Extract named entities and their relationships from the following image.", 1),
+        ("Hello world", False, "Extract named entities and their relationships from the following text.", 0),
+    ],
+)
+def test_user_prompt_for_image_vs_text(patch_settings_and_instructor, content, override_is_image, expected_user_prefix, text_block_index):
+    service = GraphService(db=None, embedding_model=None, completion_model=None)
+    entities, relationships = asyncio.run(
+        service.extract_entities_from_text(content, doc_id="doc2", chunk_number=1, override_is_image=override_is_image)
+    )
+    dummy = patch_settings_and_instructor
+    _, user_msg = dummy.captured_messages
+    # The default user message content is a list of blocks; check the instruction block
+    assert user_msg["content"][text_block_index]["text"].startswith(expected_user_prefix)
+    # Validate that the stubbed relationship round-trips through extraction
+    assert relationships and relationships[0].type == "related_to"
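Note for reviewers: a minimal sketch of driving the new flag from chunk metadata, mirroring what `_process_documents_for_entities` does above; the `service` and `chunk` objects are assumed to exist (their construction is elided, as in the tests):

```python
async def extract_from_chunk(service, chunk):
    # The extractor trusts the caller-provided flag rather than sniffing the
    # content, so image chunks must be marked via their "is_image" metadata.
    return await service.extract_entities_from_text(
        chunk.content,
        doc_id=chunk.document_id,
        chunk_number=chunk.chunk_number,
        override_is_image=chunk.metadata.get("is_image", False),
    )
```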