| 1 | +"""Classifier interfaces and implementations for ImageClassifierAgent. |
| 2 | +
|
| 3 | +This module defines an abstract base classifier interface so that different |
| 4 | +image classification strategies (mock, LLM-backed, future vision models, etc.) |
| 5 | +can be plugged into the agent without modifying the agent orchestration code. |
| 6 | +""" |
| 7 | + |
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +from abc import ABC, abstractmethod |
| 11 | +from typing import Any |
| 12 | + |
| 13 | +from langchain_core.language_models.chat_models import BaseChatModel |
| 14 | + |
| 15 | +from template_langgraph.agents.image_classifier_agent.models import Result |
| 16 | +from template_langgraph.loggers import get_logger |
| 17 | + |
| 18 | +logger = get_logger(__name__) |
| 19 | + |
| 20 | + |
| 21 | +class BaseClassifier(ABC): |
| 22 | + """Abstract base class for image classifiers. |
| 23 | +
|
    Implementations should return a structured ``Result`` object. The
    ``llm`` parameter is typed as the provider-agnostic ``BaseChatModel``;
    callers supply a model instance that offers the needed interface
    (e.g. ``with_structured_output``), and implementations that do not
    need a model (such as the mock) may simply ignore it.
    """

    @abstractmethod
    def predict(self, prompt: str, image: str, llm: BaseChatModel) -> Result:  # pragma: no cover - interface
        """Classify an image.

        Args:
            prompt: Instruction or question guiding the classification.
            image: Base64-encoded image string ("data" portion only).
            llm: A language / vision model instance used (if needed) by the classifier.

        Returns:
            Result: Structured classification output.
        """
        raise NotImplementedError


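# Illustrative note (not part of the module API): the ``image`` argument of
# ``BaseClassifier.predict`` is the raw base64 payload only, without any
# ``data:image/...;base64,`` prefix. A caller would typically produce it
# roughly like this (file name below is a placeholder):
#
#     import base64
#
#     with open("photo.png", "rb") as f:
#         image_b64 = base64.b64encode(f.read()).decode("utf-8")
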
class MockClassifier(BaseClassifier):
    """Simple mock classifier used for tests / offline development."""

    def predict(self, prompt: str, image: str, llm: Any) -> Result:  # noqa: D401
        import time

        time.sleep(3)  # Simulate a long-running process
        return Result(
            title="Mocked Image Title",
            summary=f"Mocked summary of the prompt: {prompt}",
            labels=["mocked_label_1", "mocked_label_2"],
            reliability=0.95,
        )


class LlmClassifier(BaseClassifier):
    """LLM-backed classifier using the provided model's structured output capability."""

    def predict(self, prompt: str, image: str, llm: BaseChatModel) -> Result:  # noqa: D401
        logger.info(f"Classifying image with LLM: {prompt}")
        return llm.with_structured_output(Result).invoke(
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image",
                            "source_type": "base64",
                            "data": image,
| 75 | + "mime_type": "image/png", |
| 76 | + }, |
| 77 | + ], |
| 78 | + } |
| 79 | + ] |
| 80 | + ) |
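

# The sketch below is illustrative only; it shows how either classifier can be
# swapped in behind the same ``BaseClassifier`` interface, matching the plugin
# design described in the module docstring. The ``ChatOpenAI`` wrapper, model
# name, and image path are assumptions for the example, not requirements of
# this module; the agent may supply any ``BaseChatModel``.
if __name__ == "__main__":  # pragma: no cover - usage sketch
    import base64

    from langchain_openai import ChatOpenAI  # assumed provider for the example

    with open("example.png", "rb") as f:  # placeholder image path
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model name

    # Offline / test path: the mock ignores the llm argument.
    classifier: BaseClassifier = MockClassifier()
    result = classifier.predict(prompt="Describe this image", image=image_b64, llm=llm)
    logger.info(f"Mock result: {result}")

    # LLM-backed path: the model must support with_structured_output().
    classifier = LlmClassifier()
    result = classifier.predict(prompt="Describe this image", image=image_b64, llm=llm)
    logger.info(f"LLM result: {result}")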