Skip to content
This repository was archived by the owner on Nov 10, 2025. It is now read-only.

Commit 83892cd

Browse files
Allow setting custom LLM for the vision tool (#294)
* Allow setting a custom LLM for the vision tool (defaults to gpt-4o-mini otherwise).
* Enhance VisionTool with model management and improved initialization:
  - Added support for setting a custom model identifier with a default of "gpt-4o-mini".
  - Introduced properties for model management, allowing dynamic updates and resetting of the LLM instance.
  - Updated the initialization method to accept an optional LLM and model parameter.
  - Refactored the image processing logic for clarity and efficiency.
* Docstrings.
* Add stop config.

---------

Co-authored-by: lorenzejay <lorenzejaytech@gmail.com>
1 parent a819296 commit 83892cd

File tree

1 file changed

+52
-16
lines changed

1 file changed

+52
-16
lines changed

crewai_tools/tools/vision_tool/vision_tool.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from pathlib import Path
33
from typing import Optional, Type
44

5+
from crewai import LLM
56
from crewai.tools import BaseTool
6-
from openai import OpenAI
7-
from pydantic import BaseModel, field_validator
7+
from pydantic import BaseModel, PrivateAttr, field_validator
88

99

1010
class ImagePromptSchema(BaseModel):
@@ -32,27 +32,59 @@ def validate_image_path_url(cls, v: str) -> str:
3232

3333

3434
class VisionTool(BaseTool):
35+
"""Tool for analyzing images using vision models.
36+
37+
Args:
38+
llm: Optional LLM instance to use
39+
model: Model identifier to use if no LLM is provided
40+
"""
41+
3542
name: str = "Vision Tool"
3643
description: str = (
3744
"This tool uses OpenAI's Vision API to describe the contents of an image."
3845
)
3946
args_schema: Type[BaseModel] = ImagePromptSchema
40-
_client: Optional[OpenAI] = None
47+
48+
_model: str = PrivateAttr(default="gpt-4o-mini")
49+
_llm: Optional[LLM] = PrivateAttr(default=None)
50+
51+
def __init__(self, llm: Optional[LLM] = None, model: str = "gpt-4o-mini", **kwargs):
    """Create a VisionTool.

    Args:
        llm: Pre-configured LLM instance to use for image description.
            When omitted, one is built lazily from ``model`` on first use.
        model: Model identifier used to construct a default LLM when no
            ``llm`` is supplied.
        **kwargs: Forwarded unchanged to the BaseTool constructor.
    """
    super().__init__(**kwargs)
    # Stash both values in private attrs; the `llm` property decides
    # later whether a default instance needs to be created.
    self._llm = llm
    self._model = model
62+
63+
@property
def model(self) -> str:
    """Return the model identifier currently configured for this tool."""
    return self._model

@model.setter
def model(self, value: str) -> None:
    """Update the model identifier.

    If a cached LLM exists and was built for a different model, it is
    dropped so the next access to ``llm`` lazily rebuilds one for the
    new model.

    Args:
        value: New model identifier (e.g. ``"gpt-4o-mini"``).
    """
    self._model = value
    if self._llm is not None:
        # Fix: the original read the private attribute ``self._llm._model``,
        # but crewai's LLM exposes the identifier publicly as ``model`` —
        # reading ``_model`` can raise AttributeError. Read defensively,
        # accepting either spelling.
        current = getattr(self._llm, "model", getattr(self._llm, "_model", None))
        if current != value:
            self._llm = None
4174

4275
@property
def llm(self) -> LLM:
    """Lazily build and cache the LLM used to describe images.

    Returns the instance supplied at construction time when one was given;
    otherwise constructs a default LLM from ``self._model`` on first access
    and reuses it afterwards.
    """
    if self._llm is None:
        # NOTE(review): stop sequences are hard-coded here, so any response
        # containing "STOP" or "END" will be truncated — confirm intended.
        self._llm = LLM(model=self._model, stop=["STOP", "END"])
    return self._llm
4881

4982
def _run(self, **kwargs) -> str:
5083
try:
5184
image_path_url = kwargs.get("image_path_url")
5285
if not image_path_url:
5386
return "Image Path or URL is required."
5487

55-
# Validate input using Pydantic
5688
ImagePromptSchema(image_path_url=image_path_url)
5789

5890
if image_path_url.startswith("http"):
@@ -64,8 +96,7 @@ def _run(self, **kwargs) -> str:
6496
except Exception as e:
6597
return f"Error processing image: {str(e)}"
6698

67-
response = self.client.chat.completions.create(
68-
model="gpt-4o-mini",
99+
response = self.llm.call(
69100
messages=[
70101
{
71102
"role": "user",
@@ -76,16 +107,21 @@ def _run(self, **kwargs) -> str:
76107
"image_url": {"url": image_data},
77108
},
78109
],
79-
}
110+
},
80111
],
81-
max_tokens=300,
82112
)
83-
84-
return response.choices[0].message.content
85-
113+
return response
86114
except Exception as e:
87115
return f"An error occurred: {str(e)}"
88116

89117
def _encode_image(self, image_path: str) -> str:
    """Encode an image file as base64.

    Args:
        image_path: Path to the image file

    Returns:
        Base64-encoded image data (ASCII string)
    """
    # Path.read_bytes() opens in binary mode and closes the file for us.
    raw = Path(image_path).read_bytes()
    return base64.b64encode(raw).decode("utf-8")

0 commit comments

Comments
 (0)