Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion omnitool/gradio/agent/llm_utils/oaiclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,86 @@ def run_oai_interleaved(messages: list, system: str, model_name: str, api_key: s
return text, token_usage
except Exception as e:
print(f"Error in interleaved openAI: {e}. This may due to your invalid API key. Please check the response: {response.json()} ")
return response.json()
return response.json()

def run_azure_oai_interleaved(
    messages: list,
    system: str,
    deployment_name: str,
    api_key: str,
    api_version: str = "2025-01-01-preview",
    resource_name: str = None,
    max_tokens: int = 256,
    temperature: float = 0
):
    """
    Azure OpenAI version of run_oai_interleaved.

    Sends a chat-completions request to an Azure OpenAI deployment, building
    an interleaved text/image message list in the same shape as
    run_oai_interleaved.

    Args:
        messages: List of message dicts/strings, or a single message string.
        system: System prompt, prepended as the first message.
        deployment_name: Azure OpenAI deployment name.
        api_key: Azure OpenAI API key.
        api_version: Azure REST API version to use.
        resource_name: Azure OpenAI resource name (required).
        max_tokens: Maximum tokens for the completion.
        temperature: Temperature for response generation.

    Returns:
        (text, token_usage) on success; the raw JSON response dict on error,
        mirroring run_oai_interleaved's error contract.

    Raises:
        ValueError: If resource_name is not provided.
    """
    if not resource_name:
        raise ValueError("resource_name is required for Azure OpenAI")

    headers = {
        "Content-Type": "application/json",
        "api-key": api_key
    }

    # Azure routes requests per-deployment (not per-model like openai.com).
    provider_base_url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_name}"

    final_messages = [{"role": "system", "content": system}]

    if isinstance(messages, list):
        for item in messages:
            contents = []
            if isinstance(item, dict):
                for cnt in item["content"]:
                    if isinstance(cnt, str):
                        # NOTE(review): images are only inlined for deployments
                        # whose name contains 'gpt-4-vision'; other vision-capable
                        # deployments (e.g. gpt-4o) receive the raw image path as
                        # plain text — confirm this gating is intended.
                        if is_image_path(cnt) and 'gpt-4-vision' in deployment_name.lower():
                            base64_image = encode_image(cnt)
                            content = {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{base64_image}"
                                }
                            }
                        else:
                            content = {"type": "text", "text": cnt}
                    else:
                        # Non-string entries are coerced to text.
                        content = {"type": "text", "text": str(cnt)}
                    contents.append(content)
                message = {"role": "user", "content": contents}
            else:  # plain string item
                contents.append({"type": "text", "text": item})
                message = {"role": "user", "content": contents}
            final_messages.append(message)
    elif isinstance(messages, str):
        # A bare string replaces the whole message list (no system message),
        # matching run_oai_interleaved's handling.
        final_messages = [{"role": "user", "content": messages}]

    payload = {
        "messages": final_messages,
        "max_tokens": max_tokens,
        "temperature": temperature
    }

    response = requests.post(
        f"{provider_base_url}/chat/completions?api-version={api_version}",
        headers=headers,
        json=payload,
        timeout=120,  # avoid hanging indefinitely on a stalled connection
    )

    try:
        data = response.json()
        text = data['choices'][0]['message']['content']
        token_usage = int(data.get('usage', {}).get('total_tokens', 0))
        return text, token_usage
    except Exception as e:
        # Bug fix: the handler previously re-indexed
        # response.json()['choices'][0]['message']['content'], re-raising the
        # very error that brought us here. Return the raw JSON instead so the
        # caller can inspect the error, matching run_oai_interleaved.
        print(f"Error in Azure OpenAI call: {e}. Response: {response.json()}")
        return response.json()
61 changes: 47 additions & 14 deletions omnitool/gradio/agent/vlm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from anthropic.types import ToolResultBlockParam
from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam, BetaUsage

from agent.llm_utils.oaiclient import run_oai_interleaved
from agent.llm_utils.oaiclient import run_oai_interleaved, run_azure_oai_interleaved
from agent.llm_utils.groqclient import run_groq_interleaved
from agent.llm_utils.utils import is_image_path
import time
Expand Down Expand Up @@ -38,6 +38,7 @@ def __init__(
max_tokens: int = 4096,
only_n_most_recent_images: int | None = None,
print_usage: bool = True,
azure_resource_name: str = None,
):
if model == "omniparser + gpt-4o":
self.model = "gpt-4o-2024-11-20"
Expand All @@ -59,6 +60,7 @@ def __init__(
self.max_tokens = max_tokens
self.only_n_most_recent_images = only_n_most_recent_images
self.output_callback = output_callback
self.azure_resource_name = azure_resource_name

self.print_usage = print_usage
self.total_token_usage = 0
Expand Down Expand Up @@ -92,23 +94,54 @@ def __call__(self, messages: list, parsed_screen: list[str, list, dict]):

start = time.time()
if "gpt" in self.model or "o1" in self.model or "o3-mini" in self.model:
vlm_response, token_usage = run_oai_interleaved(
messages=planner_messages,
system=system,
model_name=self.model,
api_key=self.api_key,
max_tokens=self.max_tokens,
provider_base_url="https://api.openai.com/v1",
temperature=0,
)
print(f"oai token usage: {token_usage}")
# Map model names to Azure deployment names
deployment_name = {
"gpt-4o-2024-11-20": "gpt-4o-0628", # adjust to your actual deployment name
"o1": "o1",
"o3-mini": "o3-mini"
}.get(self.model)

if self.provider == "azure":
if not self.azure_resource_name:
raise ValueError("azure_resource_name is required when using Azure OpenAI")

try:
result = run_azure_oai_interleaved(
messages=planner_messages,
system=system,
deployment_name=deployment_name,
api_key=self.api_key,
resource_name=self.azure_resource_name,
max_tokens=self.max_tokens,
)

if isinstance(result, tuple):
vlm_response, token_usage = result
else:
# Handle error case or missing token usage
vlm_response = result
token_usage = 0 # Default to 0 if not provided
print("Warning: Token usage information not available from Azure OpenAI")
except Exception as e:
print(f"Error in Azure OpenAI call: {e}")
raise e

else: # openai
vlm_response, token_usage = run_oai_interleaved(
messages=planner_messages,
system=system,
model_name=self.model,
api_key=self.api_key,
max_tokens=self.max_tokens,
)
print(f"{self.provider} oai token usage: {token_usage}")
self.total_token_usage += token_usage
if 'gpt' in self.model:
self.total_cost += (token_usage * 2.5 / 1000000) # https://openai.com/api/pricing/
self.total_cost += (token_usage * 2.5 / 1000000)
elif 'o1' in self.model:
self.total_cost += (token_usage * 15 / 1000000) # https://openai.com/api/pricing/
self.total_cost += (token_usage * 15 / 1000000)
elif 'o3-mini' in self.model:
self.total_cost += (token_usage * 1.1 / 1000000) # https://openai.com/api/pricing/
self.total_cost += (token_usage * 1.1 / 1000000)
elif "r1" in self.model:
vlm_response, token_usage = run_groq_interleaved(
messages=planner_messages,
Expand Down
62 changes: 50 additions & 12 deletions omnitool/gradio/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
API_KEY_FILE = CONFIG_DIR / "api_key"

INTRO_TEXT = '''
OmniParser lets you turn any vision-langauge model into an AI agent. We currently support **OpenAI (4o/o1/o3-mini), DeepSeek (R1), Qwen (2.5VL) or Anthropic Computer Use (Sonnet).**
OmniParser lets you turn any vision-language model into an AI agent. We currently support **OpenAI (4o/o1/o3-mini) [Both Azure and OpenAI APIs], DeepSeek (R1), Qwen (2.5VL) or Anthropic Computer Use (Sonnet).**

Type a message and press submit to start OmniTool. Press stop to pause, and press the trash icon in the chat to clear the message history.
'''
Expand Down Expand Up @@ -59,7 +59,10 @@ def setup_state(state):
if "anthropic_api_key" not in state:
state["anthropic_api_key"] = os.getenv("ANTHROPIC_API_KEY", "")
if "api_key" not in state:
state["api_key"] = ""
if state.get("provider") in ["openai", "azure"]:
state["api_key"] = os.getenv("OPENAI_API_KEY", "")
else:
state["api_key"] = ""
if "auth_validated" not in state:
state["auth_validated"] = False
if "responses" not in state:
Expand All @@ -72,6 +75,8 @@ def setup_state(state):
state['chatbot_messages'] = []
if 'stop' not in state:
state['stop'] = False
if "azure_resource_name" not in state:
state["azure_resource_name"] = ""

async def main(state):
"""Render loop for Gradio"""
Expand Down Expand Up @@ -202,6 +207,9 @@ def valid_params(user_input, state):
if not state["api_key"].strip():
errors.append("LLM API Key is not set")

if state["provider"] == "azure" and not state.get("azure_resource_name", "").strip():
errors.append("Azure Resource Name is required when using Azure OpenAI")

if not user_input:
errors.append("no computer use request provided")

Expand Down Expand Up @@ -241,8 +249,9 @@ def process_input(user_input, state):
api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
api_key=state["api_key"],
only_n_most_recent_images=state["only_n_most_recent_images"],
max_tokens=16384,
omniparser_url=args.omniparser_server_url
max_tokens=4096,
omniparser_url=args.omniparser_server_url,
azure_resource_name=state.get("azure_resource_name") if state["provider"] == "azure" else None
):
if loop_msg is None or state.get("stop"):
yield state['chatbot_messages']
Expand Down Expand Up @@ -331,6 +340,14 @@ def get_header_image_base64():
placeholder="Paste your API key here",
interactive=True,
)
with gr.Row():
azure_resource_name = gr.Textbox(
label="Azure Resource Name",
value="",
placeholder="Required for Azure OpenAI",
interactive=True,
visible=False # Initially hidden
)

with gr.Row():
with gr.Column(scale=8):
Expand All @@ -357,7 +374,7 @@ def update_model(model_selection, state):
if model_selection == "claude-3-5-sonnet-20241022":
provider_choices = [option.value for option in APIProvider if option.value != "openai"]
elif model_selection in set(["omniparser + gpt-4o", "omniparser + o1", "omniparser + o3-mini"]):
provider_choices = ["openai"]
provider_choices = ["openai", "azure"]
elif model_selection == "omniparser + R1":
provider_choices = ["groq"]
elif model_selection == "omniparser + qwen2.5vl":
Expand All @@ -384,26 +401,46 @@ def update_model(model_selection, state):
value=state["api_key"]
)

return provider_update, api_key_update
# Add Azure visibility update
azure_visible = default_provider_value == "azure"
azure_resource_update = gr.update(visible=azure_visible)


return provider_update, api_key_update, azure_resource_update

def update_only_n_images(only_n_images_value, state):
state["only_n_most_recent_images"] = only_n_images_value

def update_provider(provider_value, state):
# Update state
state["provider"] = provider_value
state["api_key"] = state.get(f"{provider_value}_api_key", "")
if provider_value in ["openai", "azure"]:
state["api_key"] = state.get("openai_api_key", "")
else:
state["api_key"] = state.get(f"{provider_value}_api_key", "")

# Update Azure resource name visibility
azure_visible = provider_value == "azure"
azure_resource_update = gr.update(visible=azure_visible)

# Calls to update other components UI
api_key_update = gr.update(
placeholder=f"{provider_value.title()} API Key",
placeholder="OpenAI API Key", # Keep it as OpenAI API Key for both
value=state["api_key"]
)
return api_key_update


return api_key_update, azure_resource_update

def update_azure_resource(resource_name, state):
state["azure_resource_name"] = resource_name

def update_api_key(api_key_value, state):
state["api_key"] = api_key_value
state[f'{state["provider"]}_api_key'] = api_key_value
# Store in openai_api_key if provider is openai or azure
if state["provider"] in ["openai", "azure"]:
state["openai_api_key"] = api_key_value
else:
state[f'{state["provider"]}_api_key'] = api_key_value

def clear_chat(state):
# Reset message-related state
Expand All @@ -415,8 +452,9 @@ def clear_chat(state):

model.change(fn=update_model, inputs=[model, state], outputs=[provider, api_key])
only_n_images.change(fn=update_only_n_images, inputs=[only_n_images, state], outputs=None)
provider.change(fn=update_provider, inputs=[provider, state], outputs=api_key)
provider.change(fn=update_provider, inputs=[provider, state], outputs=[api_key, azure_resource_name])
api_key.change(fn=update_api_key, inputs=[api_key, state], outputs=None)
azure_resource_name.change(fn=update_azure_resource, inputs=[azure_resource_name, state], outputs=None)
chatbot.clear(fn=clear_chat, inputs=[state], outputs=[chatbot])

submit_button.click(process_input, [chat_input, state], chatbot)
Expand Down
8 changes: 6 additions & 2 deletions omnitool/gradio/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ class APIProvider(StrEnum):
BEDROCK = "bedrock"
VERTEX = "vertex"
OPENAI = "openai"
AZURE = "azure"


PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
APIProvider.OPENAI: "gpt-4o",
APIProvider.AZURE: "gpt-4o"
}

def sampling_loop_sync(
Expand All @@ -47,7 +49,8 @@ def sampling_loop_sync(
api_key: str,
only_n_most_recent_images: int | None = 2,
max_tokens: int = 4096,
omniparser_url: str
omniparser_url: str,
azure_resource_name: str = None
):
"""
Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
Expand All @@ -72,7 +75,8 @@ def sampling_loop_sync(
api_response_callback=api_response_callback,
output_callback=output_callback,
max_tokens=max_tokens,
only_n_most_recent_images=only_n_most_recent_images
only_n_most_recent_images=only_n_most_recent_images,
azure_resource_name=azure_resource_name if provider == APIProvider.AZURE else None
)
else:
raise ValueError(f"Model {model} not supported")
Expand Down