Skip to content

Commit 3ff0bc8

Browse files
committed
feat: add Gemini model integration & UI fixes
1 parent ed0e783 commit 3ff0bc8

File tree

6 files changed

+216
-73
lines changed

6 files changed

+216
-73
lines changed

app/llm.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,14 @@ def __init__(self):
5959
self.model = ModelFactory.create_model(self.model_name, base_url, api_key, context)
6060

6161
def get_settings_values(self) -> tuple[str, str, str]:
    """
    Resolve the model name, API base URL and API key from ``self.settings_dict``.

    - A missing or empty 'model' falls back to DEFAULT_MODEL_NAME.
    - A missing or empty 'base_url' falls back to the OpenAI endpoint; the URL is
      normalised so that it ends with exactly one '/'.
    - Model names starting with 'gemini' use 'gemini_api_key' instead of 'api_key'.

    Returns:
        (model_name, base_url, api_key). ``api_key`` is None when the relevant key
        is not present in the settings.
    """
    settings = self.settings_dict

    model_name = settings.get('model') or DEFAULT_MODEL_NAME
    base_url = (settings.get('base_url') or 'https://api.openai.com/v1/').rstrip('/') + '/'
    api_key = settings.get('api_key')

    if model_name.startswith('gemini'):
        api_key = settings.get('gemini_api_key')
        # Only the bare placeholder 'gemini' needs resolving to a concrete model.
        # A concrete name (e.g. one typed into the custom model field) is kept as-is
        # instead of being overwritten by the saved combobox value or the default.
        if model_name == 'gemini':
            model_name = settings.get('gemini_model') or 'gemini-2.0-flash'

    return model_name, base_url, api_key
7471

7572
def read_context_txt_file(self) -> str:

app/models/factory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from models.gpt4o import GPT4o
22
from models.gpt4v import GPT4v
3+
from models.gemini import Gemini
34

45

56
class ModelFactory:
@@ -10,6 +11,8 @@ def create_model(model_name, *args):
1011
return GPT4o(model_name, *args)
1112
elif model_name == 'gpt-4-vision-preview' or model_name == 'gpt-4-turbo':
1213
return GPT4v(model_name, *args)
14+
elif model_name.startswith("gemini"):
15+
return Gemini(model_name, *args[1:])
1316
else:
1417
# Llama/Llava models will work with the standard code I wrote for GPT4V without the assistant mode features of gpt4o
1518
return GPT4v(model_name, *args)

app/models/gemini.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import json
2+
from typing import Any
3+
4+
from google import genai
5+
from google.genai import types
6+
from utils.screen import Screen
7+
8+
9+
class Gemini:
    """
    Adapter around the google-genai client.

    Sends the user's request together with a screenshot to a Gemini model and
    parses the textual reply into a dict of JSON instructions.
    """

    def __init__(self, model_name, api_key, context):
        """
        Args:
            model_name: Name of the Gemini model passed to ``generate_content``.
            api_key: API key used to create the ``genai.Client``.
            context: Text prepended to every request sent to the model.
        """
        self.model_name = model_name
        self.api_key = api_key
        self.context = context
        self.client = genai.Client(api_key=api_key)

    def get_instructions_for_objective(self, original_user_request: str, step_num: int = 0) -> dict[str, Any]:
        """
        Ask the model for the next instructions for ``original_user_request``.

        Returns the parsed JSON object, or an empty dict when the reply contains
        no parseable JSON object.
        """
        # Request BLOCK_NONE for every harm category the SDK enumerates, skipping
        # the "unspecified" placeholder member.
        safety_settings = [
            types.SafetySetting(category=category.value, threshold="BLOCK_NONE")
            for category in types.HarmCategory
            if category.value != 'HARM_CATEGORY_UNSPECIFIED'
        ]
        message_content = self.format_user_request_for_llm(original_user_request, step_num)

        llm_response = self.client.models.generate_content(
            model=self.model_name,
            contents=message_content,
            config=types.GenerateContentConfig(safety_settings=safety_settings),
        )
        json_instructions: dict[str, Any] = self.convert_llm_response_to_json_instructions(llm_response)
        return json_instructions

    def format_user_request_for_llm(self, original_user_request, step_num) -> list[Any]:
        """
        Build the ``contents`` payload: one text part (context + JSON-encoded
        request) followed by one inline image part holding the current screenshot.
        """
        base64_img: str = Screen().get_screenshot_in_base64()

        request_data: str = json.dumps({
            "original_user_request": original_user_request,
            "step_num": step_num,
        })

        message_content = [
            {"text": self.context + request_data + "\n\nHere is a screenshot of the user's screen:"},
            {"inline_data": {
                # NOTE(review): assumes Screen produces a JPEG screenshot - confirm.
                "mime_type": "image/jpeg",
                "data": base64_img,
            }},
        ]
        return message_content

    def convert_llm_response_to_json_instructions(self, llm_response) -> dict[str, Any]:
        """
        Extract the JSON object embedded in the model's reply.

        Everything between the first '{' and the last '}' of ``llm_response.text``
        is parsed, so surrounding prose or code fences are ignored. Returns an
        empty dict (after printing a message) when the reply has no text, holds
        no JSON object, or the object cannot be parsed.
        """
        # ``text`` can be None when the response carries no text parts; guard it
        # here so that ``.strip()`` cannot raise outside the error handling below.
        raw_text = getattr(llm_response, 'text', None)
        if not raw_text:
            print("Error while parsing JSON response - the model returned no text")
            return {}

        llm_response_data = raw_text.strip()

        start_index = llm_response_data.find("{")
        end_index = llm_response_data.rfind("}")
        if start_index == -1 or end_index < start_index:
            print("Error while parsing JSON response - no JSON object found in the reply")
            return {}

        try:
            json_response = json.loads(llm_response_data[start_index:end_index + 1])
        except (json.JSONDecodeError, RecursionError) as e:
            print(f"Error while parsing JSON response - {e}")
            json_response = {}

        return json_response

    def cleanup(self):
        """No resources to release; present so callers can invoke cleanup() on this model."""
        pass

app/ui.py

Lines changed: 118 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,41 @@ class AdvancedSettingsWindow(ttk.Toplevel):
3131
Self-contained settings sub-window for the UI
3232
"""
3333

34+
_instance = None
3435
def __init__(self, parent):
    """
    Build the Advanced Settings window, or surface the one that is already open.

    Only one window is kept alive at a time via the class-level ``_instance``.
    After construction the widgets are populated from the saved settings.
    """
    existing = UI.AdvancedSettingsWindow._instance
    if existing is not None:
        if existing.winfo_exists():
            # Already open: bring it to the front instead of building a second one.
            existing.lift()
            return
        # The previous window was destroyed without on_close()/save_button()
        # running (tkinter destroys child windows together with their parent),
        # so the stored reference is stale and must not block a new window.
        UI.AdvancedSettingsWindow._instance = None

    super().__init__(parent)
    UI.AdvancedSettingsWindow._instance = self
    self.title('Advanced Settings')
    self.minsize(300, 300)
    self.settings = Settings()
    self.create_widgets()
    self.protocol("WM_DELETE_WINDOW", self.on_close)

    # Populate UI
    settings_dict = self.settings.get_dict()

    if 'base_url' in settings_dict:
        self.base_url_entry.insert(0, settings_dict['base_url'])

    if 'selected_model_radio' in settings_dict:
        # Settings written by this window remember the radio choice directly.
        self.model_var.set(settings_dict['selected_model_radio'])
    elif 'model' in settings_dict:
        # No stored radio choice: infer it from the stored model name.
        loaded_model = settings_dict['model']
        if loaded_model in self.gemini_models:
            self.model_var.set('gemini')
        else:
            self.model_var.set('custom')
            self.model_entry.insert(0, loaded_model)
    else:
        self.model_var.set(DEFAULT_MODEL_NAME)

    self.update_ui_visibility()
64+
65+
def on_close(self):
    """Clear the class-level instance reference, then destroy this window."""
    UI.AdvancedSettingsWindow._instance = None
    self.destroy()
68+
5369
def create_widgets(self) -> None:
5470
# Radio buttons for model selection
5571
ttk.Label(self, text='Select Model:', bootstyle="primary").pack(pady=10, padx=10)
@@ -64,53 +80,111 @@ def create_widgets(self) -> None:
6480
('GPT-4o-mini (Cheapest, Fastest)', 'gpt-4o-mini'),
6581
('GPT-4v (Deprecated. Most-Accurate, Slowest)', 'gpt-4-vision-preview'),
6682
('GPT-4-Turbo (Least Accurate, Fast)', 'gpt-4-turbo'),
83+
('Gemini (Free, Fast)', 'gemini'),
6784
('Custom (Specify Settings Below)', 'custom')
6885
]
6986
for text, value in models:
70-
ttk.Radiobutton(radio_frame, text=text, value=value, variable=self.model_var, bootstyle="info").pack(
71-
anchor=ttk.W, pady=5)
87+
radio_button = ttk.Radiobutton(radio_frame, text=text, value=value, variable=self.model_var, bootstyle="info")
88+
radio_button.pack(anchor=ttk.W, pady=5)
89+
radio_button.config(command=self.update_ui_visibility)
90+
91+
# Gemini Model Selection
92+
self.gemini_model_label = ttk.Label(self, text='Select or Type Gemini Model:', bootstyle="primary")
93+
self.gemini_model_var = ttk.StringVar(value='gemini-2.0-flash')
94+
95+
self.gemini_model_frame = ttk.Frame(self)
96+
97+
self.gemini_models = [
98+
'gemini-2.0-flash',
99+
'gemini-2.0-flash-lite',
100+
'gemini-2.0-flash-thinking-exp',
101+
'gemini-2.0-pro-exp-02-05'
102+
]
103+
self.gemini_model_combobox = ttk.Combobox(self.gemini_model_frame, textvariable=self.gemini_model_var, values=self.gemini_models, width=30)
104+
self.gemini_model_combobox.pack(pady=5)
105+
self.gemini_model_combobox.set('gemini-2.0-flash')
72106

73-
label_base_url = ttk.Label(self, text='Custom OpenAI-Like API Model Base URL', bootstyle="secondary")
74-
label_base_url.pack(pady=10)
107+
# Custom Base URL
108+
self.custom_frame = ttk.Frame(self)
109+
110+
self.label_base_url = ttk.Label(self.custom_frame, text='Custom OpenAI-Like API Model Base URL', bootstyle="secondary")
111+
self.label_base_url.pack(pady=10)
75112

76113
# Entry for Base URL
77-
self.base_url_entry = ttk.Entry(self, width=30)
114+
self.base_url_entry = ttk.Entry(self.custom_frame, width=30)
78115
self.base_url_entry.pack()
79116

80117
# Model Label
81-
label_model = ttk.Label(self, text='Custom Model Name:', bootstyle="secondary")
82-
label_model.pack(pady=10)
118+
self.label_model = ttk.Label(self.custom_frame, text='Custom Model Name:', bootstyle="secondary")
119+
self.label_model.pack(pady=10)
83120

84121
# Entry for Model
85-
self.model_entry = ttk.Entry(self, width=30)
122+
self.model_entry = ttk.Entry(self.custom_frame, width=30)
86123
self.model_entry.pack()
87124

88125
# Save Button
89126
save_button = ttk.Button(self, text='Save Settings', bootstyle="success", command=self.save_button)
90127
save_button.pack(pady=20)
91128

129+
def update_ui_visibility(self):
    """
    Show only the option widgets that belong to the selected model radio button.

    All optional sections are hidden first. The Gemini model selector is shown
    for the 'gemini' choice and the custom base URL / model name fields for the
    'custom' choice; any other choice leaves both hidden.
    """
    model_choice = self.model_var.get()

    # First hide all optional frames
    for optional_widget in (self.gemini_model_label, self.gemini_model_frame, self.custom_frame):
        optional_widget.pack_forget()

    # Then show only the relevant ones
    if model_choice == 'gemini':
        self.gemini_model_label.pack(pady=10, padx=10)
        self.gemini_model_frame.pack(padx=20, pady=10)
        # Read the settings once so the membership check and the value used
        # come from the same snapshot.
        saved_settings = self.settings.get_dict()
        if 'gemini_model' in saved_settings:
            self.gemini_model_combobox.set(saved_settings['gemini_model'])
    elif model_choice == 'custom':
        self.custom_frame.pack(padx=20, pady=10)
        previous_model = self.settings.get_dict().get('model', DEFAULT_MODEL_NAME)
        # A saved Gemini model is not written into the custom model name field.
        if previous_model not in self.gemini_models:
            self.model_entry.delete(0, ttk.END)
            self.model_entry.insert(0, previous_model)
151+
92152
def save_button(self) -> None:
    """
    Persist the advanced settings and close the window.

    The stored 'model' is the custom entry text for the 'custom' choice, the
    Gemini combobox text for the 'gemini' choice, and the radio value itself
    otherwise. The radio choice and the Gemini model are stored as well.
    """
    entered_base_url = self.base_url_entry.get().strip()
    radio_choice = self.model_var.get()

    if radio_choice == 'custom':
        resolved_model = self.model_entry.get().strip()
    elif radio_choice == 'gemini':
        resolved_model = self.gemini_model_var.get().strip()
    else:
        resolved_model = radio_choice

    self.settings.save_settings_to_file({
        'base_url': entered_base_url,
        'model': resolved_model,
        'gemini_model': self.gemini_model_var.get().strip(),
        'selected_model_radio': self.model_var.get()
    })

    self.destroy()
    # Release the single-instance slot so the window can be opened again.
    UI.AdvancedSettingsWindow._instance = None
102171

103172
class SettingsWindow(ttk.Toplevel):
104173
"""
105174
Self-contained settings sub-window for the UI
106175
"""
107176

177+
_instance = None
108178
def __init__(self, parent):
179+
if UI.SettingsWindow._instance:
180+
return
109181
super().__init__(parent)
182+
UI.SettingsWindow._instance = self
110183
self.title('Settings')
111184
self.minsize(300, 450)
112185
self.available_themes = ['darkly', 'cyborg', 'journal', 'solar', 'superhero']
113186
self.create_widgets()
187+
self.protocol("WM_DELETE_WINDOW", self.on_close)
114188

115189
self.settings = Settings()
116190

@@ -119,20 +193,31 @@ def __init__(self, parent):
119193

120194
if 'api_key' in settings_dict:
121195
self.api_key_entry.insert(0, settings_dict['api_key'])
196+
if 'gemini_api_key' in settings_dict:
197+
self.gemini_api_key_entry.insert(0, settings_dict['gemini_api_key'])
122198
if 'default_browser' in settings_dict:
123199
self.browser_combobox.set(settings_dict['default_browser'])
124200
if 'play_ding_on_completion' in settings_dict:
125201
self.play_ding.set(1 if settings_dict['play_ding_on_completion'] else 0)
126-
if 'custom_llm_instructions':
202+
if 'custom_llm_instructions' in settings_dict:
127203
self.llm_instructions_text.insert('1.0', settings_dict['custom_llm_instructions'])
128204
self.theme_combobox.set(settings_dict.get('theme', 'superhero'))
205+
206+
def on_close(self):
    """Clear the class-level instance reference, then destroy this window."""
    UI.SettingsWindow._instance = None
    self.destroy()
209+
210+
def create_label_and_entry(self, parent, label_text):
    """
    Pack a caption label followed by a text entry into ``parent``.

    Returns the entry widget so the caller can read or pre-fill it.
    """
    caption = ttk.Label(parent, text=label_text, bootstyle="info")
    caption.pack(pady=10)

    field = ttk.Entry(parent, width=30)
    field.pack()
    return field
129215

130216
def create_widgets(self) -> None:
131217
# API Key Widgets
132-
label_api = ttk.Label(self, text='OpenAI API Key:', bootstyle="info")
133-
label_api.pack(pady=10)
134-
self.api_key_entry = ttk.Entry(self, width=30)
135-
self.api_key_entry.pack()
218+
self.api_key_entry = self.create_label_and_entry(self, 'OpenAI API Key:')
219+
220+
self.gemini_api_key_entry = self.create_label_and_entry(self, 'Gemini API Key:')
136221

137222
# Label for Browser Choice
138223
label_browser = ttk.Label(self, text='Choose Default Browser:', bootstyle="info")
@@ -179,45 +264,41 @@ def create_widgets(self) -> None:
179264
command=self.open_advanced_settings)
180265
advanced_settings_button.pack(pady=(0, 10))
181266

182-
# Hyperlink Label
183-
link_label = ttk.Label(self, text='Setup Instructions', bootstyle="primary")
184-
link_label.pack()
185-
link_label.bind('<Button-1>', lambda e: open_link(
186-
'https://github.com/AmberSahdev/Open-Interface?tab=readme-ov-file#setup-%EF%B8%8F'))
187-
188-
# Check for updates Label
189-
update_label = ttk.Label(self, text='Check for Updates', bootstyle="primary")
190-
update_label.pack()
191-
update_label.bind('<Button-1>', lambda e: open_link(
192-
'https://github.com/AmberSahdev/Open-Interface/releases/latest'))
267+
# Create clickable labels
268+
self.create_link_label('Setup Instructions',
269+
'https://github.com/AmberSahdev/Open-Interface?tab=readme-ov-file#setup-%EF%B8%8F')
270+
self.create_link_label('Check for Updates',
271+
'https://github.com/AmberSahdev/Open-Interface/releases/latest')
193272

194273
# Version Label
195274
version_label = ttk.Label(self, text=f'Version: {str(version)}', font=('Helvetica', 10))
196275
version_label.pack(side="bottom", pady=10)
197276

277+
def create_link_label(self, text, url):
    """Pack a label that calls ``open_link(url)`` when it is left-clicked."""

    def handle_click(_event):
        return open_link(url)

    clickable = ttk.Label(self, text=text, bootstyle="primary")
    clickable.pack()
    clickable.bind('<Button-1>', handle_click)
281+
198282
def on_theme_change(self, event=None) -> None:
    """Apply the currently selected theme right away via the parent window."""
    self.master.change_theme(self.theme_var.get())
202286

203287
def save_button(self) -> None:
    """
    Collect the values of the settings widgets, persist them and close the window.
    """
    openai_key = self.api_key_entry.get().strip()
    gemini_key = self.gemini_api_key_entry.get().strip()
    browser_choice = self.browser_var.get()
    ding_enabled = bool(self.play_ding.get())
    # "end-1c" drops the trailing newline the Text widget always appends.
    custom_instructions = self.llm_instructions_text.get("1.0", "end-1c").strip()
    selected_theme = self.theme_var.get()

    self.settings.save_settings_to_file({
        'api_key': openai_key,
        'gemini_api_key': gemini_key,
        'default_browser': browser_choice,
        'play_ding_on_completion': ding_enabled,
        'custom_llm_instructions': custom_instructions,
        'theme': selected_theme
    })

    self.destroy()
    # Release the single-instance slot so the window can be opened again.
    UI.SettingsWindow._instance = None
218300

219301
def open_advanced_settings(self):
    """Open the Advanced Settings sub-window with this window as its parent."""
    UI.AdvancedSettingsWindow(self)
222303

223304
class MainWindow(ttk.Window):

0 commit comments

Comments
 (0)