
Commit f28efde

Queries all selected models at the same time
In the earlier script, even when four models were selected, they were queried one at a time: the script sent the first request, waited for its output, then moved on to the second model and waited again, and so on. This commit switches to parallel execution so that all selected models are queried at the same time.
1 parent ae399d9 commit f28efde
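
The heart of the change is replacing a sequential loop over the models with a thread pool. A minimal sketch of the pattern outside Streamlit (an illustration, not the committed code; it assumes an Ollama server on localhost:11434 and uses hypothetical model names that would already be pulled):

import concurrent.futures
import requests

OLLAMA_URL = "http://localhost:11434/api/generate"  # default Ollama endpoint

def query(model, prompt):
    # One blocking request per model; requests releases the GIL while it
    # waits on the network, so threads are enough to overlap the calls.
    res = requests.post(
        OLLAMA_URL,
        json={"model": model, "prompt": prompt, "stream": False},
        timeout=600,
    )
    res.raise_for_status()
    return res.json().get("response", "")

models = ["llama3", "mistral"]  # hypothetical names for illustration
prompt = "Why is the sky blue?"

# Before this commit (sequential): total wall time is the sum of per-model times.
# for m in models:
#     print(m, query(m, prompt))

# After this commit (parallel): total wall time is roughly the slowest model's time.
with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as pool:
    futures = {pool.submit(query, m, prompt): m for m in models}
    for fut in concurrent.futures.as_completed(futures):
        print(futures[fut], "->", fut.result()[:80])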

File tree

1 file changed: +104 −50 lines

app.py

Lines changed: 104 additions & 50 deletions
@@ -1,6 +1,7 @@
 import streamlit as st
 import requests
 import time
+import concurrent.futures
 
 st.set_page_config(page_title="LLM Comparison", layout="wide")
 
@@ -50,92 +51,145 @@ def get_models():
 
 models_available = get_models()
 
+if not models_available:
+    st.warning("No models found. Ensure Ollama is running and has models pulled.")
+    st.stop()
+
 if "model_count" not in st.session_state:
     st.session_state.model_count = 2
 if "selected_models" not in st.session_state:
-    st.session_state.selected_models = ["", ""]
+    st.session_state.selected_models = [""] * st.session_state.model_count
 
+# Logic to add a new model selection
 if st.button("Add new model"):
     st.session_state.model_count += 1
     st.session_state.selected_models.append("")
 
+# Display model selection boxes and remove buttons
 for i in range(st.session_state.model_count):
     with st.container():
         col1, col2 = st.columns([0.9, 0.1])
         with col1:
+            # Ensure selected_models list is long enough for the current index
+            if i >= len(st.session_state.selected_models):
+                st.session_state.selected_models.append("")
+
+            # Determine the initial selection for the selectbox
+            current_selection_index = 0
+            if st.session_state.selected_models[i] in models_available:
+                current_selection_index = models_available.index(st.session_state.selected_models[i])
+            elif models_available:  # If previous selection isn't available, default to first available
+                st.session_state.selected_models[i] = models_available[0]
+                current_selection_index = 0
+
            st.session_state.selected_models[i] = st.selectbox(
                f"Model {i+1}",
                models_available,
-                index=models_available.index(st.session_state.selected_models[i]) if st.session_state.selected_models[i] in models_available else 0,
+                index=current_selection_index,
                key=f"model_select_{i}"
            )
         with col2:
             if st.button("x", key=f"remove_model_{i}"):
                 if st.session_state.model_count > 1:
                     st.session_state.selected_models.pop(i)
                     st.session_state.model_count -= 1
-                    st.rerun()
+                    st.rerun()  # Rerun to update the UI immediately after removal
 
 run = st.button("Run Models", type="primary")
 
 # Main display area
 st.title("Running LLMs in parallel")
 
+# Function to query a single Ollama model
+def query_ollama_model(model_name, prompt_text):
+    """Function to query a single Ollama model."""
+    try:
+        start_time = time.time()
+        res = requests.post(
+            "http://localhost:11434/api/generate",
+            json={"model": model_name, "prompt": prompt_text, "stream": False},
+            headers={"Content-Type": "application/json"},
+        )
+        res.raise_for_status()
+        response_data = res.json()
+        end_time = time.time()
+
+        duration = round(end_time - start_time, 2)
+        content = response_data.get("response", "")
+        # Fallback for eval_count if not present, approximate by word count
+        eval_count = response_data.get("eval_count", len(content.split()))
+        eval_rate = response_data.get("eval_rate", round(eval_count / duration, 2) if duration > 0 else 0)
+
+        return {
+            "model": model_name,
+            "duration": duration,
+            "eval_count": eval_count,
+            "eval_rate": eval_rate,
+            "response": content
+        }
+    except Exception as e:
+        return {
+            "model": model_name,
+            "duration": 0,
+            "eval_count": 0,
+            "eval_rate": 0,
+            "response": f"Error: {e}"
+        }
+
 if run and prompt.strip():
     model_inputs = [model for model in st.session_state.selected_models if model]
 
     if not model_inputs:
         st.warning("Please select at least one model to generate a response.")
     else:
+        # Create columns dynamically based on the number of selected models
         cols = st.columns(len(model_inputs))
-        placeholders = [col.empty() for col in cols]  # Create placeholders for each model's output
+
+        # Use a dictionary to store placeholder containers for each model's output
+        model_output_containers = {}
 
         for i, model in enumerate(model_inputs):
-            with placeholders[i].container():
+            with cols[i]:
                 model_color = "blue" if i % 2 == 0 else "red"
-                st.markdown(
-                    f"<h3 style='color:{model_color};'>{model}</h3>",
-                    unsafe_allow_html=True
-                )
+                # Display the model name ONCE at the top of its column
+                st.markdown(f"<h3 style='color:{model_color};'>{model}</h3>", unsafe_allow_html=True)
+
+                # Create an empty placeholder where the spinner and later the content will live
+                model_output_containers[model] = st.empty()
 
-                # Use st.spinner for the loading indicator
-                with st.spinner(f"Running {model}..."):
-                    start = time.time()
-                    try:
-                        response = requests.post(
-                            "http://localhost:11434/api/generate",
-                            json={"model": model, "prompt": prompt, "stream": False},
-                        ).json()
-
-                        duration = round(time.time() - start, 2)
-                        content = response.get("response", "").strip()
-                        eval_count = response.get("eval_count", len(content.split()))
-                        eval_rate = response.get("eval_rate", round(eval_count / duration, 2))
-
-                        # Clear the spinner and display the actual response
-                        placeholders[i].empty()  # Clear the placeholder content including the spinner
-                        with placeholders[i].container():  # Redraw content
-                            st.markdown(
-                                f"<h3 style='color:{model_color};'>{model}</h3>",
-                                unsafe_allow_html=True
-                            )
-                            st.markdown(
-                                f"""
-                                <div style="background-color:#e6f0ff; padding:10px; border-radius:8px; margin-bottom:10px;">
-                                    <b>Duration</b>: <span style="color:#3366cc;">{duration} secs</span><br>
-                                    <b>Eval count</b>: <span style="color:green;">{eval_count} tokens</span><br>
-                                    <b>Eval rate</b>: <span style="color:green;">{eval_rate} tokens/s</span>
-                                </div>
-                                """,
-                                unsafe_allow_html=True
-                            )
-                            st.write(content)
-
-                    except Exception as e:
-                        placeholders[i].empty()  # Clear the placeholder content including the spinner
-                        with placeholders[i].container():  # Redraw content
-                            st.markdown(
-                                f"<h3 style='color:{model_color};'>{model}</h3>",
-                                unsafe_allow_html=True
-                            )
-                            st.error(f"Error: {e}")
+                # Show spinner in this container
+                with model_output_containers[model].container():
+                    st.spinner(f"Running {model}...")
+
+
+        # Use ThreadPoolExecutor for concurrent execution
+        all_results = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=len(model_inputs)) as executor:
+            # Map models to futures
+            future_to_model = {executor.submit(query_ollama_model, model, prompt): model for model in model_inputs}
+
+            # As futures complete, update the respective placeholders
+            for future in concurrent.futures.as_completed(future_to_model):
+                model_name = future_to_model[future]
+                result = future.result()  # Get the result (either success or error dict)
+                all_results.append(result)
+
+                # Update the content of the specific model's container
+                # The .empty() call on the container is implicitly handled by writing new content to it.
+                with model_output_containers[model_name].container():
+                    # No need to re-display model name here, it's already at the top of the column
+
+                    if "Error" in result["response"]:
+                        st.error(f"Error: {result['response']}")
+                    else:
+                        st.markdown(
+                            f"""
+                            <div style="background-color:#e6f0ff; padding:10px; border-radius:8px; margin-bottom:10px;">
+                                <b>Duration</b>: <span style="color:#3366cc;">{result['duration']} secs</span><br>
+                                <b>Eval count</b>: <span style="color:green;">{result['eval_count']} tokens</span><br>
+                                <b>Eval rate</b>: <span style="color:green;">{result['eval_rate']} tokens/s</span>
+                            </div>
+                            """,
+                            unsafe_allow_html=True
+                        )
+                        st.write(result["response"])
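
Since query_ollama_model is plain requests code with no Streamlit dependency, it can be smoke-tested on its own before wiring it into the UI. A hypothetical check (the model name is an assumption; substitute any model you have pulled):

# Hypothetical standalone check, not part of this commit.
# Assumes Ollama is serving on localhost:11434 and "llama3" has been pulled.
result = query_ollama_model("llama3", "Say hi in five words.")
print(f"{result['model']}: {result['duration']} secs, {result['eval_rate']} tokens/s")
print(result["response"])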

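One note on the metrics block: to my knowledge, Ollama's non-streaming /api/generate response reports eval_count plus eval_duration (in nanoseconds) rather than a ready-made eval_rate field, so the response_data.get("eval_rate", ...) fallback is what normally runs. If that assumption holds, tokens/s could be derived inside query_ollama_model from the server-reported generation time instead of wall-clock time:

# Sketch under the assumption that Ollama reports eval_duration in nanoseconds
# (a variant of the committed code, not part of this commit).
eval_count = response_data.get("eval_count", len(content.split()))
eval_ns = response_data.get("eval_duration")
if eval_ns:
    eval_rate = round(eval_count / (eval_ns / 1e9), 2)  # tokens per second of generation
else:
    eval_rate = round(eval_count / duration, 2) if duration > 0 else 0  # wall-clock fallback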