Skip to content

Commit 38d5d0e

Browse files
committed
add image classifier agent
1 parent 350d5c1 commit 38d5d0e

File tree

5 files changed

+267
-1
lines changed

5 files changed

+267
-1
lines changed

langgraph.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
"kabuto_helpdesk_agent": "template_langgraph.agents.kabuto_helpdesk_agent.agent:graph",
99
"issue_formatter_agent": "template_langgraph.agents.issue_formatter_agent.agent:graph",
1010
"task_decomposer_agent": "template_langgraph.agents.task_decomposer_agent.agent:graph",
11-
"news_summarizer_agent": "template_langgraph.agents.news_summarizer_agent.agent:graph"
11+
"news_summarizer_agent": "template_langgraph.agents.news_summarizer_agent.agent:graph",
12+
"image_classifier_agent": "template_langgraph.agents.image_classifier_agent.agent:graph"
1213
},
1314
"env": ".env"
1415
}

scripts/agent_operator.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from dotenv import load_dotenv
66

77
from template_langgraph.agents.chat_with_tools_agent.agent import graph as chat_with_tools_agent_graph
8+
from template_langgraph.agents.image_classifier_agent.agent import graph as image_classifier_agent_graph
9+
from template_langgraph.agents.image_classifier_agent.models import Results
810
from template_langgraph.agents.issue_formatter_agent.agent import graph as issue_formatter_agent_graph
911
from template_langgraph.agents.kabuto_helpdesk_agent.agent import graph as kabuto_helpdesk_agent_graph
1012
from template_langgraph.agents.news_summarizer_agent.agent import (
@@ -35,6 +37,8 @@ def get_agent_graph(name: str):
3537
return kabuto_helpdesk_agent_graph
3638
elif name == "news_summarizer_agent":
3739
return news_summarizer_agent_graph
40+
elif name == "image_classifier_agent":
41+
return image_classifier_agent_graph
3842
else:
3943
raise ValueError(f"Unknown agent name: {name}")
4044

@@ -165,6 +169,55 @@ def news_summarizer_agent(
165169
logger.info(f"{article.structured_article.model_dump_json(indent=2)}")
166170

167171

172+
@app.command()
def image_classifier_agent(
    prompt: str = typer.Option(
        "Please classify the image.",
        "--prompt",
        "-p",
        help="Prompt for the agent",
    ),
    file_paths: str = typer.Option(
        "./docs/images/fastapi.png,./docs/images/jupyterlab.png",
        "--file-paths",
        "-f",
        help="Comma-separated list of file paths to classify",
    ),
    verbose: bool = typer.Option(
        False,
        "--verbose",
        "-v",
        help="Enable verbose output",
    ),
):
    """Run the image classifier agent on the given files and log the results."""
    from template_langgraph.agents.image_classifier_agent.models import (
        AgentInputState,
        AgentState,
    )

    # Set up logging
    if verbose:
        logger.setLevel(logging.DEBUG)

    # Tolerate "a.png, b.png" and trailing commas: strip each entry and drop
    # empty segments so no blank path is handed to the agent.
    paths = [path.strip() for path in file_paths.split(",") if path.strip()] if file_paths else []

    graph = image_classifier_agent_graph
    for event in graph.stream(
        input=AgentState(
            input=AgentInputState(
                prompt=prompt,
                id=str(uuid4()),
                file_paths=paths,
            ),
            results=[],
        )
    ):
        logger.info("-" * 20)
        logger.info(f"Event: {event}")

        # Only the event emitted by the "notify" node carries the collected
        # results; other events are just logged above.
        if "notify" not in event:
            continue

        results: list[Results] = event["notify"]["results"]
        for result in results:
            logger.info(f"{result.model_dump_json(indent=2)}")
219+
220+
168221
if __name__ == "__main__":
169222
load_dotenv(
170223
override=True,

template_langgraph/agents/image_classifier_agent/__init__.py

Whitespace-only changes.
template_langgraph/agents/image_classifier_agent/agent.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import os
2+
from base64 import b64encode
3+
4+
import httpx
5+
from langgraph.graph import StateGraph
6+
from langgraph.types import Send
7+
8+
from template_langgraph.agents.image_classifier_agent.models import (
9+
AgentState,
10+
ClassifyImageState,
11+
Result,
12+
Results,
13+
)
14+
from template_langgraph.llms.azure_openais import AzureOpenAiWrapper
15+
from template_langgraph.loggers import get_logger
16+
17+
logger = get_logger(__name__)
18+
19+
20+
def load_image_to_base64(image_path: str) -> str:
    """Read the file at ``image_path`` and return its content as Base64 text.

    Args:
        image_path: Path of the file to read; it is opened in binary mode.

    Returns:
        The Base64 encoding of the file's bytes, decoded to a ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = b64encode(raw_bytes)
    return encoded.decode("utf-8")
23+
24+
25+
class MockNotifier:
    """Stand-in notifier that only writes the notification to the log."""

    def notify(self, id: str, body: dict) -> None:
        """Log the request id and body instead of delivering a notification.

        Args:
            id: Identifier of the request the notification belongs to.
            body: Payload that would be sent to the user.
        """
        message = f"Notification sent for request {id}: {body}"
        logger.info(message)
29+
30+
31+
class MockClassifier:
    """Classifier stub that returns a fixed result without calling any model."""

    def predict(
        self,
        prompt: str,
        image: str,
        llm=None,
    ) -> Result:
        """Return a canned classification result.

        Args:
            prompt: Classification prompt; echoed into the mocked summary.
            image: Image payload; ignored by the mock.
            llm: Accepted so the signature matches the other classifiers but
                never used. It defaults to ``None`` so that defining this
                class does not construct an ``AzureOpenAiWrapper`` client.

        Returns:
            A ``Result`` with fixed title, labels and reliability.
        """
        return Result(
            title="Mocked Image Title",
            summary=f"Mocked summary of the prompt: {prompt}",
            labels=["mocked_label_1", "mocked_label_2"],
            reliability=0.95,
        )
45+
46+
47+
class LlmClassifier:
    """Classifier that sends the prompt and image to a chat model."""

    # Base64 text of a JPEG starts with "/9j/" (bytes FF D8 FF).
    _JPEG_BASE64_PREFIX = "/9j/"

    @staticmethod
    def _guess_mime_type(image: str) -> str:
        """Guess the MIME type from the leading characters of Base64 image data.

        Returns ``"image/jpeg"`` for JPEG data and ``"image/png"`` otherwise,
        which keeps the previous hard-coded value as the fallback.
        """
        if image.startswith(LlmClassifier._JPEG_BASE64_PREFIX):
            return "image/jpeg"
        return "image/png"

    def predict(
        self,
        prompt: str,
        image: str,
        llm=None,
        mime_type=None,
    ) -> "Result":
        """Use the LLM to classify the image.

        Args:
            prompt: Text instruction sent alongside the image.
            image: Base64-encoded image data.
            llm: Chat model to call. When ``None``, a model is created from
                ``AzureOpenAiWrapper`` at call time; it is not built in the
                signature so importing this module does not create a client.
            mime_type: MIME type declared for the image. When ``None`` it is
                guessed from the data so JPEG files are not labelled as PNG.

        Returns:
            The model's answer parsed into a ``Result``.
        """
        if llm is None:
            llm = AzureOpenAiWrapper().chat_model
        if mime_type is None:
            mime_type = self._guess_mime_type(image)

        logger.info(f"Classifying image with LLM: {prompt}")
        return llm.with_structured_output(Result).invoke(
            input=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt,
                        },
                        {
                            "type": "image",
                            "source_type": "base64",
                            "data": image,
                            "mime_type": mime_type,
                        },
                    ],
                },
            ]
        )
75+
76+
77+
class ImageClassifierAgent:
    """Agent that classifies a list of image files and notifies the result.

    The graph runs ``initialize``, then one ``classify_image`` task per input
    file path, then ``notify``.
    """

    # File extensions accepted by classify_image (compared case-insensitively).
    SUPPORTED_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(
        self,
        llm=None,
        notifier=None,
        classifier=None,
    ):
        """Store the collaborators used by the graph nodes.

        Args:
            llm: Chat model passed to the classifier. Defaults to the model
                from a new ``AzureOpenAiWrapper``.
            notifier: Object with ``notify(id, body)``. Defaults to a new
                ``MockNotifier``.
            classifier: Object with ``predict(prompt, image, llm)``. Defaults
                to a new ``MockClassifier``.
        """
        # Defaults are created here rather than in the signature so that they
        # are not instantiated once at class-definition (import) time and
        # shared between every agent instance.
        self.llm = AzureOpenAiWrapper().chat_model if llm is None else llm
        self.notifier = MockNotifier() if notifier is None else notifier
        self.classifier = MockClassifier() if classifier is None else classifier

    def create_graph(self):
        """Create the main graph for the agent."""
        # Create the workflow state graph
        workflow = StateGraph(AgentState)

        # Create nodes
        workflow.add_node("initialize", self.initialize)
        workflow.add_node("classify_image", self.classify_image)
        workflow.add_node("notify", self.notify)

        # Create edges
        workflow.set_entry_point("initialize")
        workflow.add_conditional_edges(
            source="initialize",
            path=self.run_subtasks,
            path_map={
                "classify_image": "classify_image",
            },
        )
        workflow.add_edge("classify_image", "notify")
        workflow.set_finish_point("notify")
        return workflow.compile(
            name=ImageClassifierAgent.__name__,
        )

    def initialize(self, state: AgentState) -> AgentState:
        """Initialize the agent state (currently returns it unchanged)."""
        logger.info(f"Initializing state: {state}")
        # FIXME: retrieve urls from user request
        return state

    def run_subtasks(self, state: AgentState) -> list[Send]:
        """Build one ``classify_image`` task per input file path."""
        logger.info(f"Running subtasks with state: {state}")
        return [
            Send(
                node="classify_image",
                arg=ClassifyImageState(
                    prompt=state.input.prompt,
                    file_path=file_path,
                ),
            )
            for file_path in state.input.file_paths
        ]

    def classify_image(self, state: ClassifyImageState):
        """Classify a single image file.

        Returns:
            ``{"results": [Results(...)]}`` on success, or
            ``{"results": []}`` when the file is skipped or classification
            fails, so the node always returns the same shape.
        """
        logger.info(f"Classify file: {state.file_path}")

        # Compare the extension case-insensitively so "PHOTO.PNG" is accepted.
        has_supported_extension = state.file_path.lower().endswith(self.SUPPORTED_EXTENSIONS)
        if not (has_supported_extension and os.path.isfile(state.file_path)):
            logger.warning(f"Skipping unsupported or missing file: {state.file_path}")
            return {"results": []}

        try:
            logger.info(f"Loading file: {state.file_path}")
            base64_image = load_image_to_base64(state.file_path)

            logger.info(f"Classifying file: {state.file_path}")
            result = self.classifier.predict(
                prompt=state.prompt,
                image=base64_image,
                llm=self.llm,
            )
        except (OSError, httpx.RequestError) as e:
            # OSError covers failures while reading the file from disk;
            # httpx.RequestError is kept from the original handler.
            logger.error(f"Error classifying file {state.file_path}: {e}")
            return {"results": []}

        logger.info(f"Classification result: {result.model_dump_json(indent=2)}")
        return {
            "results": [
                Results(
                    file_path=state.file_path,
                    result=result,
                ),
            ]
        }

    def notify(self, state: AgentState) -> AgentState:
        """Send notifications to the user."""
        logger.info(f"Sending notifications with state: {state}")
        # Simulate sending notifications
        summary = {index: result.model_dump() for index, result in enumerate(state.results)}
        self.notifier.notify(
            id=state.input.id,
            body=summary,
        )
        # NOTE(review): the whole state is returned, including `results`, which
        # is declared with an `operator.add` annotation in AgentState. Confirm
        # this does not append the results to themselves in the final state.
        return state
171+
172+
173+
# For testing: the no-argument constructor falls back to MockClassifier.
# graph = ImageClassifierAgent().create_graph()

# Compiled graph exposed at module level; it is imported elsewhere as
# `template_langgraph.agents.image_classifier_agent.agent:graph`.
graph = ImageClassifierAgent(
    llm=AzureOpenAiWrapper().chat_model,
    notifier=MockNotifier(),
    classifier=LlmClassifier(),
).create_graph()
template_langgraph/agents/image_classifier_agent/models.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import operator
2+
from typing import Annotated
3+
4+
from pydantic import BaseModel, Field
5+
6+
7+
# Input for a single classification task: one prompt and one image file path.
class ClassifyImageState(BaseModel):
    # Both fields are required (`...` means no default).
    prompt: str = Field(..., description="Prompt for classification")
    file_path: str = Field(..., description="Image file path")
10+
11+
12+
# Classification outcome for one image. All fields are required.
class Result(BaseModel):
    title: str = Field(..., description="Title of the image")
    summary: str = Field(..., description="Summary of the image")
    labels: list[str] = Field(..., description="Labels extracted from the image")
    # The 0-to-1 range is stated in the description only; no validation
    # constraint enforces it here.
    reliability: float = Field(..., description="Reliability score of the classification from 0 to 1")
17+
18+
19+
# Pairs one image file path with its classification Result (a single entry,
# despite the plural class name).
class Results(BaseModel):
    file_path: str = Field(..., description="Image file path")
    result: Result = Field(..., description="Structured representation of the image classification result")
22+
23+
24+
# Request-level input for the agent. All fields are required.
class AgentInputState(BaseModel):
    prompt: str = Field(..., description="Prompt for the agent")
    id: str = Field(..., description="Unique identifier for the request")
    file_paths: list[str] = Field(..., description="List of image file paths")
28+
29+
30+
# Overall agent state: the request input plus the accumulated results.
class AgentState(BaseModel):
    input: AgentInputState = Field(..., description="Input state for the agent")
    # Required field (no default). `operator.add` is attached as annotation
    # metadata; presumably the graph framework uses it to concatenate result
    # lists written by separate classification tasks. TODO confirm.
    results: Annotated[list[Results], operator.add]

0 commit comments

Comments
 (0)