Skip to content

Commit dc25fed

Browse files
authored
Merge pull request #74 from ks6088ts-labs/feature/issue-73_classifier-agent
image classifier agent
2 parents 53631f8 + 8e99e5d commit dc25fed

File tree

10 files changed

+503
-55
lines changed

10 files changed

+503
-55
lines changed

langgraph.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
"kabuto_helpdesk_agent": "template_langgraph.agents.kabuto_helpdesk_agent.agent:graph",
99
"issue_formatter_agent": "template_langgraph.agents.issue_formatter_agent.agent:graph",
1010
"task_decomposer_agent": "template_langgraph.agents.task_decomposer_agent.agent:graph",
11-
"news_summarizer_agent": "template_langgraph.agents.news_summarizer_agent.agent:graph"
11+
"news_summarizer_agent": "template_langgraph.agents.news_summarizer_agent.agent:graph",
12+
"image_classifier_agent": "template_langgraph.agents.image_classifier_agent.agent:graph"
1213
},
1314
"env": ".env"
1415
}

scripts/agent_operator.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from dotenv import load_dotenv
66

77
from template_langgraph.agents.chat_with_tools_agent.agent import graph as chat_with_tools_agent_graph
8+
from template_langgraph.agents.image_classifier_agent.agent import graph as image_classifier_agent_graph
9+
from template_langgraph.agents.image_classifier_agent.models import Results
810
from template_langgraph.agents.issue_formatter_agent.agent import graph as issue_formatter_agent_graph
911
from template_langgraph.agents.kabuto_helpdesk_agent.agent import graph as kabuto_helpdesk_agent_graph
1012
from template_langgraph.agents.news_summarizer_agent.agent import (
@@ -35,6 +37,8 @@ def get_agent_graph(name: str):
3537
return kabuto_helpdesk_agent_graph
3638
elif name == "news_summarizer_agent":
3739
return news_summarizer_agent_graph
40+
elif name == "image_classifier_agent":
41+
return image_classifier_agent_graph
3842
else:
3943
raise ValueError(f"Unknown agent name: {name}")
4044

@@ -165,6 +169,55 @@ def news_summarizer_agent(
165169
logger.info(f"{article.structured_article.model_dump_json(indent=2)}")
166170

167171

172+
@app.command()
173+
def image_classifier_agent(
174+
prompt: str = typer.Option(
175+
"Please classify the image.",
176+
"--prompt",
177+
"-p",
178+
help="Prompt for the agent",
179+
),
180+
file_paths: str = typer.Option(
181+
"./docs/images/fastapi.png,./docs/images/jupyterlab.png",
182+
"--file-paths",
183+
"-f",
184+
help="Comma-separated list of file paths to classify",
185+
),
186+
verbose: bool = typer.Option(
187+
False,
188+
"--verbose",
189+
"-v",
190+
help="Enable verbose output",
191+
),
192+
):
193+
from template_langgraph.agents.image_classifier_agent.models import (
194+
AgentInputState,
195+
AgentState,
196+
)
197+
198+
# Set up logging
199+
if verbose:
200+
logger.setLevel(logging.DEBUG)
201+
202+
graph = image_classifier_agent_graph
203+
for event in graph.stream(
204+
input=AgentState(
205+
input=AgentInputState(
206+
prompt=prompt,
207+
id=str(uuid4()),
208+
file_paths=file_paths.split(",") if file_paths else [],
209+
),
210+
results=[],
211+
)
212+
):
213+
logger.info("-" * 20)
214+
logger.info(f"Event: {event}")
215+
216+
results: list[Results] = event["notify"]["results"]
217+
for result in results:
218+
logger.info(f"{result.model_dump_json(indent=2)}")
219+
220+
168221
if __name__ == "__main__":
169222
load_dotenv(
170223
override=True,

scripts/azure_openai_operator.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from base64 import b64encode
23

34
import typer
45
from dotenv import load_dotenv
@@ -16,6 +17,11 @@
1617
logger = get_logger(__name__)
1718

1819

20+
def load_image_to_base64(image_path: str) -> str:
21+
with open(image_path, "rb") as image_file:
22+
return b64encode(image_file.read()).decode("utf-8")
23+
24+
1925
@app.command()
2026
def chat(
2127
query: str = typer.Option(
@@ -80,6 +86,63 @@ def reasoning(
8086
logger.info(f"Output: {response.content}")
8187

8288

89+
@app.command()
90+
def image(
91+
query: str = typer.Option(
92+
"Please analyze the following image and answer the question",
93+
"--query",
94+
"-q",
95+
help="Query to run with the Azure OpenAI chat model",
96+
),
97+
file_path: str = typer.Option(
98+
"./docs/images/streamlit.png",
99+
"--file",
100+
"-f",
101+
help="Path to the image file to analyze",
102+
),
103+
verbose: bool = typer.Option(
104+
False,
105+
"--verbose",
106+
"-v",
107+
help="Enable verbose output",
108+
),
109+
):
110+
# Set up logging
111+
if verbose:
112+
logger.setLevel(logging.DEBUG)
113+
114+
base64_image = load_image_to_base64(file_path)
115+
messages = {
116+
"role": "user",
117+
"content": [
118+
{
119+
"type": "text",
120+
"text": query,
121+
},
122+
{
123+
"type": "image",
124+
"source_type": "base64",
125+
"data": base64_image,
126+
"mime_type": "image/png",
127+
},
128+
],
129+
}
130+
131+
logger.info("Running...")
132+
response = AzureOpenAiWrapper().chat_model.invoke(
133+
input=[
134+
messages,
135+
],
136+
)
137+
logger.debug(
138+
response.model_dump_json(
139+
indent=2,
140+
exclude_none=True,
141+
)
142+
)
143+
logger.info(f"Output: {response.content}")
144+
145+
83146
if __name__ == "__main__":
84147
load_dotenv(
85148
override=True,

template_langgraph/agents/image_classifier_agent/__init__.py

Whitespace-only changes.
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import os
2+
from base64 import b64encode
3+
4+
import httpx
5+
from langgraph.graph import StateGraph
6+
from langgraph.types import Send
7+
8+
from template_langgraph.agents.image_classifier_agent.classifiers import (
9+
BaseClassifier,
10+
LlmClassifier,
11+
MockClassifier,
12+
)
13+
from template_langgraph.agents.image_classifier_agent.models import (
14+
AgentState,
15+
ClassifyImageState,
16+
Results,
17+
)
18+
from template_langgraph.llms.azure_openais import AzureOpenAiWrapper
19+
from template_langgraph.loggers import get_logger
20+
21+
logger = get_logger(__name__)
22+
23+
24+
def load_image_to_base64(image_path: str) -> str:
25+
with open(image_path, "rb") as image_file:
26+
return b64encode(image_file.read()).decode("utf-8")
27+
28+
29+
class MockNotifier:
30+
def notify(self, id: str, body: dict) -> None:
31+
"""Simulate sending a notification to the user."""
32+
logger.info(f"Notification sent for request {id}: {body}")
33+
34+
35+
class ImageClassifierAgent:
36+
def __init__(
37+
self,
38+
llm=AzureOpenAiWrapper().chat_model,
39+
notifier=MockNotifier(),
40+
classifier: BaseClassifier = MockClassifier(),
41+
):
42+
self.llm = llm
43+
self.notifier = notifier
44+
self.classifier: BaseClassifier = classifier
45+
46+
def create_graph(self):
47+
"""Create the main graph for the agent."""
48+
# Create the workflow state graph
49+
workflow = StateGraph(AgentState)
50+
51+
# Create nodes
52+
workflow.add_node("initialize", self.initialize)
53+
workflow.add_node("classify_image", self.classify_image)
54+
workflow.add_node("notify", self.notify)
55+
56+
# Create edges
57+
workflow.set_entry_point("initialize")
58+
workflow.add_conditional_edges(
59+
source="initialize",
60+
path=self.run_subtasks,
61+
path_map={
62+
"classify_image": "classify_image",
63+
},
64+
)
65+
workflow.add_edge("classify_image", "notify")
66+
workflow.set_finish_point("notify")
67+
return workflow.compile(
68+
name=ImageClassifierAgent.__name__,
69+
)
70+
71+
def initialize(self, state: AgentState) -> AgentState:
72+
"""Initialize the agent state."""
73+
logger.info(f"Initializing state: {state}")
74+
# FIXME: retrieve urls from user request
75+
return state
76+
77+
def run_subtasks(self, state: AgentState) -> list[Send]:
78+
"""Run the subtasks for the agent."""
79+
logger.info(f"Running subtasks with state: {state}")
80+
return [
81+
Send(
82+
node="classify_image",
83+
arg=ClassifyImageState(
84+
prompt=state.input.prompt,
85+
file_path=state.input.file_paths[idx],
86+
),
87+
)
88+
for idx, _ in enumerate(state.input.file_paths)
89+
]
90+
91+
def classify_image(self, state: ClassifyImageState):
92+
logger.info(f"Classify file: {state.file_path}")
93+
if state.file_path.endswith((".png", ".jpg", ".jpeg")) and os.path.isfile(state.file_path):
94+
try:
95+
logger.info(f"Loading file: {state.file_path}")
96+
base64_image = load_image_to_base64(state.file_path)
97+
98+
logger.info(f"Classifying file: {state.file_path}")
99+
result = self.classifier.predict(
100+
prompt=state.prompt,
101+
image=base64_image,
102+
llm=self.llm,
103+
)
104+
105+
logger.info(f"Classification result: {result.model_dump_json(indent=2)}")
106+
return {
107+
"results": [
108+
Results(
109+
file_path=state.file_path,
110+
result=result,
111+
),
112+
]
113+
}
114+
except httpx.RequestError as e:
115+
logger.error(f"Error fetching web content: {e}")
116+
117+
def notify(self, state: AgentState) -> AgentState:
118+
"""Send notifications to the user."""
119+
logger.info(f"Sending notifications with state: {state}")
120+
# Simulate sending notifications
121+
summary = {}
122+
for i, result in enumerate(state.results):
123+
summary[i] = result.model_dump()
124+
self.notifier.notify(
125+
id=state.input.id,
126+
body=summary,
127+
)
128+
return state
129+
130+
131+
# For testing
132+
# graph = ImageClassifierAgent().create_graph()
133+
134+
graph = ImageClassifierAgent(
135+
llm=AzureOpenAiWrapper().chat_model,
136+
notifier=MockNotifier(),
137+
classifier=LlmClassifier(),
138+
).create_graph()
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Classifier interfaces and implementations for ImageClassifierAgent.
2+
3+
This module defines an abstract base classifier interface so that different
4+
image classification strategies (mock, LLM-backed, future vision models, etc.)
5+
can be plugged into the agent without modifying the agent orchestration code.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from abc import ABC, abstractmethod
11+
from typing import Any
12+
13+
from langchain_core.language_models.chat_models import BaseChatModel
14+
15+
from template_langgraph.agents.image_classifier_agent.models import Result
16+
from template_langgraph.loggers import get_logger
17+
18+
logger = get_logger(__name__)
19+
20+
21+
class BaseClassifier(ABC):
22+
"""Abstract base class for image classifiers.
23+
24+
Implementations should return a structured ``Result`` object.
25+
The ``llm`` argument is kept generic (Any) to avoid tight coupling
26+
with a specific provider wrapper; callers supply a model instance
27+
that offers the needed interface (e.g. ``with_structured_output``).
28+
"""
29+
30+
@abstractmethod
31+
def predict(self, prompt: str, image: str, llm: BaseChatModel) -> Result: # pragma: no cover - interface
32+
"""Classify an image.
33+
34+
Args:
35+
prompt: Instruction or question guiding the classification.
36+
image: Base64-encoded image string ("data" portion only).
37+
llm: A language / vision model instance used (if needed) by the classifier.
38+
39+
Returns:
40+
Result: Structured classification output.
41+
"""
42+
raise NotImplementedError
43+
44+
45+
class MockClassifier(BaseClassifier):
46+
"""Simple mock classifier used for tests / offline development."""
47+
48+
def predict(self, prompt: str, image: str, llm: Any) -> Result: # noqa: D401
49+
import time
50+
51+
time.sleep(3) # Simulate a long-running process
52+
return Result(
53+
title="Mocked Image Title",
54+
summary=f"Mocked summary of the prompt: {prompt}",
55+
labels=["mocked_label_1", "mocked_label_2"],
56+
reliability=0.95,
57+
)
58+
59+
60+
class LlmClassifier(BaseClassifier):
61+
"""LLM-backed classifier using the provided model's structured output capability."""
62+
63+
def predict(self, prompt: str, image: str, llm: BaseChatModel):
64+
logger.info(f"Classifying image with LLM: {prompt}")
65+
return llm.with_structured_output(Result).invoke(
66+
input=[
67+
{
68+
"role": "user",
69+
"content": [
70+
{"type": "text", "text": prompt},
71+
{
72+
"type": "image",
73+
"source_type": "base64",
74+
"data": image,
75+
"mime_type": "image/png",
76+
},
77+
],
78+
}
79+
]
80+
)

0 commit comments

Comments
 (0)