3 changes: 2 additions & 1 deletion langgraph.json
@@ -8,7 +8,8 @@
"kabuto_helpdesk_agent": "template_langgraph.agents.kabuto_helpdesk_agent.agent:graph",
"issue_formatter_agent": "template_langgraph.agents.issue_formatter_agent.agent:graph",
"task_decomposer_agent": "template_langgraph.agents.task_decomposer_agent.agent:graph",
"news_summarizer_agent": "template_langgraph.agents.news_summarizer_agent.agent:graph"
"news_summarizer_agent": "template_langgraph.agents.news_summarizer_agent.agent:graph",
"image_classifier_agent": "template_langgraph.agents.image_classifier_agent.agent:graph"
},
"env": ".env"
}
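
Each entry in langgraph.json maps a graph name to a "module.path:attribute" import string that the LangGraph CLI resolves at startup. As a rough sketch of that resolution (the resolve_graph helper below is hypothetical, for illustration only):

from importlib import import_module


def resolve_graph(ref: str):
    # Split "package.module:attribute" into its two parts.
    module_path, attribute = ref.split(":")
    # Import the module and return the compiled graph object it exposes.
    return getattr(import_module(module_path), attribute)


# Mirrors the entry added above:
graph = resolve_graph("template_langgraph.agents.image_classifier_agent.agent:graph")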
53 changes: 53 additions & 0 deletions scripts/agent_operator.py
@@ -5,6 +5,8 @@
from dotenv import load_dotenv

from template_langgraph.agents.chat_with_tools_agent.agent import graph as chat_with_tools_agent_graph
from template_langgraph.agents.image_classifier_agent.agent import graph as image_classifier_agent_graph
from template_langgraph.agents.image_classifier_agent.models import Results
from template_langgraph.agents.issue_formatter_agent.agent import graph as issue_formatter_agent_graph
from template_langgraph.agents.kabuto_helpdesk_agent.agent import graph as kabuto_helpdesk_agent_graph
from template_langgraph.agents.news_summarizer_agent.agent import (
@@ -35,6 +37,8 @@ def get_agent_graph(name: str):
return kabuto_helpdesk_agent_graph
elif name == "news_summarizer_agent":
return news_summarizer_agent_graph
elif name == "image_classifier_agent":
return image_classifier_agent_graph
else:
raise ValueError(f"Unknown agent name: {name}")

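A design note on the dispatch above: as the elif chain grows with each new agent, a registry dict keeps lookups flat. A minimal sketch, assuming the graph aliases imported at the top of this file:

AGENT_GRAPHS = {
    "chat_with_tools_agent": chat_with_tools_agent_graph,
    "image_classifier_agent": image_classifier_agent_graph,
    "issue_formatter_agent": issue_formatter_agent_graph,
    "kabuto_helpdesk_agent": kabuto_helpdesk_agent_graph,
    "news_summarizer_agent": news_summarizer_agent_graph,
}


def get_agent_graph(name: str):
    try:
        return AGENT_GRAPHS[name]
    except KeyError:
        raise ValueError(f"Unknown agent name: {name}") from None
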
@@ -165,6 +169,55 @@ def news_summarizer_agent(
logger.info(f"{article.structured_article.model_dump_json(indent=2)}")


@app.command()
def image_classifier_agent(
prompt: str = typer.Option(
"Please classify the image.",
"--prompt",
"-p",
help="Prompt for the agent",
),
file_paths: str = typer.Option(
"./docs/images/fastapi.png,./docs/images/jupyterlab.png",
"--file-paths",
"-f",
help="Comma-separated list of file paths to classify",
),
verbose: bool = typer.Option(
False,
"--verbose",
"-v",
help="Enable verbose output",
),
):
from template_langgraph.agents.image_classifier_agent.models import (
AgentInputState,
AgentState,
)

# Set up logging
if verbose:
logger.setLevel(logging.DEBUG)

graph = image_classifier_agent_graph
for event in graph.stream(
input=AgentState(
input=AgentInputState(
prompt=prompt,
id=str(uuid4()),
file_paths=file_paths.split(",") if file_paths else [],
),
results=[],
)
):
logger.info("-" * 20)
logger.info(f"Event: {event}")

results: list[Results] = event["notify"]["results"]
for result in results:
logger.info(f"{result.model_dump_json(indent=2)}")


if __name__ == "__main__":
load_dotenv(
override=True,
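
For reference, the new command can also be reproduced programmatically without Typer. A minimal sketch, assuming the AgentState/AgentInputState models added in this PR; invoke() returns the final state, from which the accumulated results can be read:

from uuid import uuid4

from template_langgraph.agents.image_classifier_agent.agent import graph
from template_langgraph.agents.image_classifier_agent.models import (
    AgentInputState,
    AgentState,
)

final_state = graph.invoke(
    input=AgentState(
        input=AgentInputState(
            prompt="Please classify the image.",
            id=str(uuid4()),
            file_paths=["./docs/images/fastapi.png"],
        ),
        results=[],
    )
)
for result in final_state["results"]:
    print(result.model_dump_json(indent=2))
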
63 changes: 63 additions & 0 deletions scripts/azure_openai_operator.py
@@ -1,4 +1,5 @@
import logging
from base64 import b64encode

import typer
from dotenv import load_dotenv
@@ -16,6 +17,11 @@
logger = get_logger(__name__)


def load_image_to_base64(image_path: str) -> str:
with open(image_path, "rb") as image_file:
return b64encode(image_file.read()).decode("utf-8")


@app.command()
def chat(
query: str = typer.Option(
@@ -80,6 +86,63 @@ def reasoning(
logger.info(f"Output: {response.content}")


@app.command()
def image(
query: str = typer.Option(
"Please analyze the following image and answer the question",
"--query",
"-q",
help="Query to run with the Azure OpenAI chat model",
),
file_path: str = typer.Option(
"./docs/images/streamlit.png",
"--file",
"-f",
help="Path to the image file to analyze",
),
verbose: bool = typer.Option(
False,
"--verbose",
"-v",
help="Enable verbose output",
),
):
# Set up logging
if verbose:
logger.setLevel(logging.DEBUG)

base64_image = load_image_to_base64(file_path)
messages = {
"role": "user",
"content": [
{
"type": "text",
"text": query,
},
{
"type": "image",
"source_type": "base64",
"data": base64_image,
"mime_type": "image/png",
},
],
}

logger.info("Running...")
response = AzureOpenAiWrapper().chat_model.invoke(
input=[
messages,
],
)
logger.debug(
response.model_dump_json(
indent=2,
exclude_none=True,
)
)
logger.info(f"Output: {response.content}")


if __name__ == "__main__":
load_dotenv(
override=True,
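
One caveat in the image command above: the content block hardcodes mime_type as "image/png" even though --file accepts any path. A hedged sketch of deriving the MIME type from the file name with the standard-library mimetypes module (the image_content_block helper is illustrative, not part of this PR):

import mimetypes
from base64 import b64encode


def image_content_block(image_path: str) -> dict:
    # Guess the MIME type from the file extension; fall back to PNG.
    mime_type, _ = mimetypes.guess_type(image_path)
    with open(image_path, "rb") as image_file:
        data = b64encode(image_file.read()).decode("utf-8")
    return {
        "type": "image",
        "source_type": "base64",
        "data": data,
        "mime_type": mime_type or "image/png",
    }
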
Empty file.
138 changes: 138 additions & 0 deletions template_langgraph/agents/image_classifier_agent/agent.py
@@ -0,0 +1,138 @@
import os
from base64 import b64encode

import httpx
from langgraph.graph import StateGraph
from langgraph.types import Send

from template_langgraph.agents.image_classifier_agent.classifiers import (
BaseClassifier,
LlmClassifier,
MockClassifier,
)
from template_langgraph.agents.image_classifier_agent.models import (
AgentState,
ClassifyImageState,
Results,
)
from template_langgraph.llms.azure_openais import AzureOpenAiWrapper
from template_langgraph.loggers import get_logger

logger = get_logger(__name__)


def load_image_to_base64(image_path: str) -> str:
with open(image_path, "rb") as image_file:
return b64encode(image_file.read()).decode("utf-8")


class MockNotifier:
def notify(self, id: str, body: dict) -> None:
"""Simulate sending a notification to the user."""
logger.info(f"Notification sent for request {id}: {body}")


class ImageClassifierAgent:
def __init__(
self,
llm=AzureOpenAiWrapper().chat_model,
notifier=MockNotifier(),
classifier: BaseClassifier = MockClassifier(),
):
self.llm = llm
self.notifier = notifier
self.classifier: BaseClassifier = classifier

def create_graph(self):
"""Create the main graph for the agent."""
# Create the workflow state graph
workflow = StateGraph(AgentState)

# Create nodes
workflow.add_node("initialize", self.initialize)
workflow.add_node("classify_image", self.classify_image)
workflow.add_node("notify", self.notify)

# Create edges
workflow.set_entry_point("initialize")
workflow.add_conditional_edges(
source="initialize",
path=self.run_subtasks,
path_map={
"classify_image": "classify_image",
},
)
workflow.add_edge("classify_image", "notify")
workflow.set_finish_point("notify")
return workflow.compile(
name=ImageClassifierAgent.__name__,
)

def initialize(self, state: AgentState) -> AgentState:
"""Initialize the agent state."""
logger.info(f"Initializing state: {state}")
# FIXME: retrieve urls from user request
return state

def run_subtasks(self, state: AgentState) -> list[Send]:
"""Run the subtasks for the agent."""
logger.info(f"Running subtasks with state: {state}")
return [
Send(
node="classify_image",
arg=ClassifyImageState(
prompt=state.input.prompt,
file_path=state.input.file_paths[idx],
),
)
for idx, _ in enumerate(state.input.file_paths)
]

def classify_image(self, state: ClassifyImageState):
logger.info(f"Classify file: {state.file_path}")
if state.file_path.endswith((".png", ".jpg", ".jpeg")) and os.path.isfile(state.file_path):
try:
logger.info(f"Loading file: {state.file_path}")
base64_image = load_image_to_base64(state.file_path)

logger.info(f"Classifying file: {state.file_path}")
result = self.classifier.predict(
prompt=state.prompt,
image=base64_image,
llm=self.llm,
)

logger.info(f"Classification result: {result.model_dump_json(indent=2)}")
return {
"results": [
Results(
file_path=state.file_path,
result=result,
),
]
}
except httpx.RequestError as e:
logger.error(f"Error classifying image: {e}")

def notify(self, state: AgentState) -> AgentState:
"""Send notifications to the user."""
logger.info(f"Sending notifications with state: {state}")
# Simulate sending notifications
summary = {}
for i, result in enumerate(state.results):
summary[i] = result.model_dump()
self.notifier.notify(
id=state.input.id,
body=summary,
)
return state


# For testing
# graph = ImageClassifierAgent().create_graph()

graph = ImageClassifierAgent(
llm=AzureOpenAiWrapper().chat_model,
notifier=MockNotifier(),
classifier=LlmClassifier(),
).create_graph()
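
Because the notifier and classifier are injected, the graph can be exercised offline by swapping in the mock implementations, along the lines of the commented-out test wiring above. A minimal sketch (no model call is made at invocation time, since MockClassifier ignores the llm argument):

test_graph = ImageClassifierAgent(
    notifier=MockNotifier(),
    classifier=MockClassifier(),
).create_graph()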
80 changes: 80 additions & 0 deletions template_langgraph/agents/image_classifier_agent/classifiers.py
@@ -0,0 +1,80 @@
"""Classifier interfaces and implementations for ImageClassifierAgent.

This module defines an abstract base classifier interface so that different
image classification strategies (mock, LLM-backed, future vision models, etc.)
can be plugged into the agent without modifying the agent orchestration code.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from langchain_core.language_models.chat_models import BaseChatModel

from template_langgraph.agents.image_classifier_agent.models import Result
from template_langgraph.loggers import get_logger

logger = get_logger(__name__)


class BaseClassifier(ABC):
"""Abstract base class for image classifiers.

Implementations should return a structured ``Result`` object.
The ``llm`` argument is annotated as ``BaseChatModel`` to document the
expected interface without coupling to a specific provider wrapper;
callers supply a model instance that offers the needed capabilities
(e.g. ``with_structured_output``), and implementations that ignore the
model (such as the mock) may accept ``Any``.
"""

@abstractmethod
def predict(self, prompt: str, image: str, llm: BaseChatModel) -> Result: # pragma: no cover - interface
"""Classify an image.

Args:
prompt: Instruction or question guiding the classification.
image: Base64-encoded image string ("data" portion only).
llm: A language / vision model instance used (if needed) by the classifier.

Returns:
Result: Structured classification output.
"""
raise NotImplementedError


class MockClassifier(BaseClassifier):
"""Simple mock classifier used for tests / offline development."""

def predict(self, prompt: str, image: str, llm: Any) -> Result: # noqa: D401
import time

time.sleep(3) # Simulate a long-running process
return Result(
title="Mocked Image Title",
summary=f"Mocked summary of the prompt: {prompt}",
labels=["mocked_label_1", "mocked_label_2"],
reliability=0.95,
)


class LlmClassifier(BaseClassifier):
"""LLM-backed classifier using the provided model's structured output capability."""

def predict(self, prompt: str, image: str, llm: BaseChatModel) -> Result:
logger.info(f"Classifying image with LLM: {prompt}")
return llm.with_structured_output(Result).invoke(
input=[
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{
"type": "image",
"source_type": "base64",
"data": image,
"mime_type": "image/png",
},
],
}
]
)
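
The abstract base makes additional strategies pluggable without touching the agent orchestration. A hypothetical third implementation, purely illustrative (the KeywordClassifier name and its heuristic are assumptions, not part of this PR):

class KeywordClassifier(BaseClassifier):
    """Toy classifier that derives labels from the prompt alone."""

    def predict(self, prompt: str, image: str, llm: BaseChatModel) -> Result:
        # No model call: label from a simple keyword heuristic, with low
        # reliability to signal that the image bytes were never inspected.
        labels = ["screenshot"] if "screen" in prompt.lower() else ["unknown"]
        return Result(
            title="Keyword-based guess",
            summary=f"Derived labels from the prompt: {prompt}",
            labels=labels,
            reliability=0.1,
        )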