 import streamlit as st
 import requests
 import time
+import concurrent.futures
 
 st.set_page_config(page_title="LLM Comparison", layout="wide")
 
@@ -50,92 +51,145 @@ def get_models(): |
 
 models_available = get_models()
 
+if not models_available:
+    st.warning("No models found. Ensure Ollama is running and has models pulled.")
+    st.stop()
+
 if "model_count" not in st.session_state:
     st.session_state.model_count = 2
 if "selected_models" not in st.session_state:
-    st.session_state.selected_models = ["", ""]
+    st.session_state.selected_models = [""] * st.session_state.model_count
 
+# Logic to add a new model selection
 if st.button("Add new model"):
     st.session_state.model_count += 1
     st.session_state.selected_models.append("")
 
+# Display model selection boxes and remove buttons
 for i in range(st.session_state.model_count):
     with st.container():
         col1, col2 = st.columns([0.9, 0.1])
         with col1:
+            # Ensure selected_models is long enough for the current index
+            if i >= len(st.session_state.selected_models):
+                st.session_state.selected_models.append("")
+
+            # Determine the initial selection for the selectbox
+            current_selection_index = 0
+            if st.session_state.selected_models[i] in models_available:
+                current_selection_index = models_available.index(st.session_state.selected_models[i])
+            else:
+                # Previous selection is no longer available; fall back to the first
+                # model (models_available is guaranteed non-empty past st.stop() above)
+                st.session_state.selected_models[i] = models_available[0]
+
             st.session_state.selected_models[i] = st.selectbox(
                 f"Model {i+1}",
                 models_available,
-                index=models_available.index(st.session_state.selected_models[i]) if st.session_state.selected_models[i] in models_available else 0,
+                index=current_selection_index,
                 key=f"model_select_{i}"
             )
         with col2:
             if st.button("x", key=f"remove_model_{i}"):
                 if st.session_state.model_count > 1:
                     st.session_state.selected_models.pop(i)
                     st.session_state.model_count -= 1
-                    st.rerun()
+                    st.rerun()  # Rerun to update the UI immediately after removal
 
 run = st.button("Run Models", type="primary")
 
 # Main display area
 st.title("Running LLMs in parallel")
 
+def query_ollama_model(model_name, prompt_text):
+    """Query a single Ollama model and return its response plus timing stats."""
+    try:
+        start_time = time.time()
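+        # "stream": False asks Ollama to return the whole completion in one JSON body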
+        res = requests.post(
+            "http://localhost:11434/api/generate",
+            json={"model": model_name, "prompt": prompt_text, "stream": False},
+            headers={"Content-Type": "application/json"},
+            timeout=600,  # generation can be slow, but don't hang forever on a dead server
+        )
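+        # Surface HTTP errors (e.g. 404 for a model that isn't pulled) in the except branch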
+        res.raise_for_status()
+        response_data = res.json()
+        end_time = time.time()
+
+        duration = round(end_time - start_time, 2)
+        content = response_data.get("response", "")
+        # Fall back to a word-count approximation if eval_count is missing
+        eval_count = response_data.get("eval_count", len(content.split()))
+        # Ollama reports eval_duration in nanoseconds; prefer it over wall-clock
+        # time, which also includes network and prompt-processing overhead
+        eval_duration_ns = response_data.get("eval_duration")
+        if eval_duration_ns:
+            eval_rate = round(eval_count / (eval_duration_ns / 1e9), 2)
+        else:
+            eval_rate = round(eval_count / duration, 2) if duration > 0 else 0
+
+        return {
+            "model": model_name,
+            "duration": duration,
+            "eval_count": eval_count,
+            "eval_rate": eval_rate,
+            "response": content,
+            "error": None,
+        }
+    except Exception as e:
+        # Return an explicit error field; the UI checks it instead of
+        # scanning the response text for the word "Error"
+        return {
+            "model": model_name,
+            "duration": 0,
+            "eval_count": 0,
+            "eval_rate": 0,
+            "response": "",
+            "error": str(e),
+        }
+
 if run and prompt.strip():
     model_inputs = [model for model in st.session_state.selected_models if model]
 
     if not model_inputs:
         st.warning("Please select at least one model to generate a response.")
     else:
+        # Create columns dynamically based on the number of selected models
         cols = st.columns(len(model_inputs))
-        placeholders = [col.empty() for col in cols]  # Create placeholders for each model's output
+
+        # Placeholder containers for each model's output, keyed by column index
+        # (the same model may be selected twice, so the index is the safe key)
+        model_output_containers = {}
 
         for i, model in enumerate(model_inputs):
-            with placeholders[i].container():
+            with cols[i]:
                 model_color = "blue" if i % 2 == 0 else "red"
-                st.markdown(
-                    f"<h3 style='color:{model_color};'>{model}</h3>",
-                    unsafe_allow_html=True
-                )
+                # Display the model name once at the top of its column
+                st.markdown(f"<h3 style='color:{model_color};'>{model}</h3>", unsafe_allow_html=True)
+
+                # Empty placeholder where the status message and, later, the result will live
+                model_output_containers[i] = st.empty()
 
-                # Use st.spinner for the loading indicator
-                with st.spinner(f"Running {model}..."):
-                    start = time.time()
-                    try:
-                        response = requests.post(
-                            "http://localhost:11434/api/generate",
-                            json={"model": model, "prompt": prompt, "stream": False},
-                        ).json()
-
-                        duration = round(time.time() - start, 2)
-                        content = response.get("response", "").strip()
-                        eval_count = response.get("eval_count", len(content.split()))
-                        eval_rate = response.get("eval_rate", round(eval_count / duration, 2))
-
-                        # Clear the spinner and display the actual response
-                        placeholders[i].empty()  # Clear the placeholder content including the spinner
-                        with placeholders[i].container():  # Redraw content
-                            st.markdown(
-                                f"<h3 style='color:{model_color};'>{model}</h3>",
-                                unsafe_allow_html=True
-                            )
-                            st.markdown(
-                                f"""
-                                <div style="background-color:#e6f0ff; padding:10px; border-radius:8px; margin-bottom:10px;">
-                                    <b>Duration</b>: <span style="color:#3366cc;">{duration} secs</span><br>
-                                    <b>Eval count</b>: <span style="color:green;">{eval_count} tokens</span><br>
-                                    <b>Eval rate</b>: <span style="color:green;">{eval_rate} tokens/s</span>
-                                </div>
-                                """,
-                                unsafe_allow_html=True
-                            )
-                            st.write(content)
-
-                    except Exception as e:
-                        placeholders[i].empty()  # Clear the placeholder content including the spinner
-                        with placeholders[i].container():  # Redraw content
-                            st.markdown(
-                                f"<h3 style='color:{model_color};'>{model}</h3>",
-                                unsafe_allow_html=True
-                            )
-                            st.error(f"Error: {e}")
+                # Show a status message in this placeholder until the result arrives.
+                # (A bare st.spinner() call renders nothing: st.spinner is a context
+                # manager that only displays while the enclosed block is running.)
+                model_output_containers[i].info(f"Running {model}...")
+
+        # Use ThreadPoolExecutor for concurrent execution
+        all_results = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=len(model_inputs)) as executor:
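+            # Only the blocking HTTP calls run in worker threads; all st.* rendering
+            # below stays on the main script thread, since Streamlit elements are not
+            # safe to update from worker threads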
+            # Map each future to its column index
+            future_to_index = {
+                executor.submit(query_ollama_model, model, prompt): i
+                for i, model in enumerate(model_inputs)
+            }
+
+            # As futures complete, update the matching placeholder
+            for future in concurrent.futures.as_completed(future_to_index):
+                i = future_to_index[future]
+                result = future.result()  # success and failure both return a result dict
+                all_results.append(result)
+
+                # Writing a new container into the st.empty placeholder replaces
+                # the status message with the final output
+                with model_output_containers[i].container():
+                    # The model name is already displayed at the top of the column
+                    if result["error"]:
+                        st.error(f"Error: {result['error']}")
+                    else:
+                        st.markdown(
+                            f"""
+                            <div style="background-color:#e6f0ff; padding:10px; border-radius:8px; margin-bottom:10px;">
+                                <b>Duration</b>: <span style="color:#3366cc;">{result['duration']} secs</span><br>
+                                <b>Eval count</b>: <span style="color:green;">{result['eval_count']} tokens</span><br>
+                                <b>Eval rate</b>: <span style="color:green;">{result['eval_rate']} tokens/s</span>
+                            </div>
+                            """,
+                            unsafe_allow_html=True
+                        )
+                        st.write(result["response"])
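
Note: get_models() appears only in the hunk header and is not shown in this diff. As a rough sketch (not the author's code), it likely lists locally pulled models via Ollama's /api/tags endpoint:

def get_models():
    # Illustrative sketch: fetch the names of locally available models
    try:
        res = requests.get("http://localhost:11434/api/tags", timeout=5)
        res.raise_for_status()
        return [m["name"] for m in res.json().get("models", [])]
    except requests.RequestException:
        return []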