Skip to content

Commit d99b2de

Browse files
authored
feat: Allow using local Ollama models for AI assistant (#360)
1 parent 9b37e1d commit d99b2de

File tree

12 files changed

+275
-64
lines changed

12 files changed

+275
-64
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ log = { workspace = true }
8888
dirs = { workspace = true }
8989
tempfile = { workspace = true }
9090
config = "0.15.19"
91-
genai = { git = "https://github.com/BinaryMuse/rust-genai", rev = "674535905a966b44104a327dd1e2ca80f4b4a444" }
91+
genai = { git = "https://github.com/BinaryMuse/rust-genai", rev = "ce7feec4ae112cc9ad442841c6b49961a599580e" }
9292
futures-util = "0.3.31"
9393
indoc = "2.0.7"
9494

backend/src/ai/fsm.rs

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -257,23 +257,18 @@ impl Agent {
257257
}
258258

259259
/// Push accumulated tool results to conversation as tool response messages.
260+
/// Each tool response becomes a separate message with ChatRole::Tool.
260261
fn push_tool_results_to_conversation(&mut self) {
261-
if !self.context.tool_results.is_empty() {
262-
let tool_result_parts = self
263-
.context
264-
.tool_results
265-
.drain(..)
266-
.map(|result| {
267-
let result_str = match result.output {
268-
ToolOutput::Success(s) => s,
269-
ToolOutput::Error(e) => format!("Error: {}", e),
270-
};
271-
ContentPart::ToolResponse(ToolResponse::new(result.call_id, result_str))
272-
})
273-
.collect::<Vec<ContentPart>>();
274-
let tool_result_content = MessageContent::from_parts(tool_result_parts);
275-
let tool_result_message = ChatMessage::user(tool_result_content);
276-
self.context.conversation.push(tool_result_message);
262+
for result in self.context.tool_results.drain(..) {
263+
let result_str = match result.output {
264+
ToolOutput::Success(s) => s,
265+
ToolOutput::Error(e) => format!("Error: {}", e),
266+
};
267+
let tool_response = ToolResponse::new(result.call_id, result_str);
268+
// ChatMessage::from(ToolResponse) creates a message with ChatRole::Tool
269+
self.context
270+
.conversation
271+
.push(ChatMessage::from(tool_response));
277272
}
278273
}
279274

@@ -824,7 +819,7 @@ mod tests {
824819
assert_eq!(t.effects, vec![Effect::Cancelled]);
825820
assert!(agent.context().pending_tools.is_empty());
826821
// Tool results were pushed to conversation as error responses
827-
// Conversation: user msg, assistant msg, tool response (with 2 cancelled results)
828-
assert_eq!(agent.context().conversation.len(), 3);
822+
// Conversation: user msg, assistant msg, tool response, tool response
823+
assert_eq!(agent.context().conversation.len(), 4);
829824
}
830825
}

backend/src/ai/session.rs

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,17 @@ impl SessionHandle {
117117
}
118118

119119
/// Send a user message to the session.
120-
pub async fn send_user_message(&self, content: String) -> Result<(), AISessionError> {
120+
pub async fn send_user_message(
121+
&self,
122+
content: String,
123+
model: ModelSelection,
124+
) -> Result<(), AISessionError> {
125+
let msg = Event::ModelChange(model);
126+
self.event_tx
127+
.send(msg)
128+
.await
129+
.map_err(|_| AISessionError::ChannelClosed)?;
130+
121131
let msg = ChatMessage::user(content);
122132
self.event_tx
123133
.send(Event::UserMessage(msg))
@@ -204,12 +214,16 @@ fn resolve_service_target(
204214
}
205215
};
206216

207-
let auth = AuthData::Key(
208-
AISession::get_api_key(adapter_kind, parts[0] == "atuinhub")
209-
.await
210-
.map_err(|e| genai::resolver::Error::Custom(e.to_string()))?,
211-
);
212-
service_target.auth = auth;
217+
let key = AISession::get_api_key(adapter_kind, parts[0] == "atuinhub")
218+
.await
219+
.map_err(|e| genai::resolver::Error::Custom(e.to_string()))?;
220+
221+
if let Some(key) = key {
222+
let auth = AuthData::Key(key);
223+
service_target.auth = auth;
224+
} else {
225+
service_target.auth = AuthData::Key("".to_string());
226+
}
213227

214228
let model_id = ModelIden::new(adapter_kind, parts[1]);
215229
service_target.model = model_id;
@@ -354,12 +368,12 @@ impl AISession {
354368
async fn get_api_key(
355369
_adapter_kind: AdapterKind,
356370
is_hub: bool,
357-
) -> Result<String, AISessionError> {
371+
) -> Result<Option<String>, AISessionError> {
358372
if is_hub {
359-
return Ok("".to_string());
373+
return Ok(None);
360374
}
361375

362-
Ok("".to_string())
376+
Ok(None)
363377
}
364378

365379
/// Run the session event loop.
@@ -598,14 +612,10 @@ impl AISession {
598612
Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
599613
log::trace!("Session {} received thought signature chunk", session_id,);
600614
}
601-
Ok(ChatStreamEvent::ToolCallChunk(tc_chunk)) => {
615+
Ok(ChatStreamEvent::ToolCallChunk(_tc_chunk)) => {
602616
// Tool call chunks are accumulated by genai internally
603617
// We'll get the complete tool calls in the End event
604-
log::trace!(
605-
"Session {} received tool call chunk: {:?}",
606-
session_id,
607-
tc_chunk
608-
);
618+
log::trace!("Session {} received tool call chunk", session_id);
609619
}
610620
Ok(ChatStreamEvent::ReasoningChunk(_)) => {
611621
log::trace!("Session {} received reasoning chunk", session_id);

backend/src/ai/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl fmt::Display for ModelSelection {
2828
},
2929
ModelSelection::Ollama { model, uri } => match uri {
3030
Some(uri) => write!(f, "ollama::{model}::{}", uri.deref()),
31-
None => write!(f, "ollama::{model}::default"),
31+
None => write!(f, "ollama::{model}::http://localhost:11434"),
3232
},
3333
}
3434
}

backend/src/commands/ai.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,15 @@ pub async fn ai_send_message(
300300
state: tauri::State<'_, AtuinState>,
301301
session_id: Uuid,
302302
message: String,
303+
model: ModelSelection,
303304
) -> Result<(), String> {
304305
let sessions = state.ai_sessions.read().await;
305306
let handle = sessions
306307
.get(&session_id)
307308
.ok_or_else(|| format!("Session {} not found", session_id))?;
308309

309310
handle
310-
.send_user_message(message)
311+
.send_user_message(message, model)
311312
.await
312313
.map_err(|e| e.to_string())
313314
}

src/components/Settings/Settings.tsx

Lines changed: 115 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import handleDeepLink from "@/routes/root/deep";
3939
import * as api from "@/api/api";
4040
import InterpreterSelector from "@/lib/blocks/common/InterpreterSelector";
4141
import AtuinEnv from "@/atuin_env";
42+
import { OllamaSettings, useAIProviderSettings } from "@/state/settings_ai";
4243

4344
async function loadFonts(): Promise<string[]> {
4445
const fonts = await invoke<string[]>("list_fonts");
@@ -49,7 +50,7 @@ async function loadFonts(): Promise<string[]> {
4950
}
5051

5152
// Custom hook for managing settings
52-
const useSettingsState = (
53+
export const useSettingsState = (
5354
_key: any,
5455
initialValue: any,
5556
settingsGetter: any,
@@ -109,6 +110,7 @@ interface SettingsSwitchProps {
109110
onValueChange: (e: boolean) => void;
110111
description: string;
111112
className?: string;
113+
isDisabled?: boolean;
112114
}
113115

114116
const SettingSwitch = ({
@@ -117,11 +119,13 @@ const SettingSwitch = ({
117119
onValueChange,
118120
description,
119121
className,
122+
isDisabled,
120123
}: SettingsSwitchProps) => (
121124
<Switch
122125
isSelected={isSelected}
123126
onValueChange={onValueChange}
124127
className={cn("flex justify-between items-center w-full", className)}
128+
isDisabled={isDisabled || false}
125129
>
126130
<div className="flex flex-col">
127131
<span>{label}</span>
@@ -1190,29 +1194,122 @@ const AISettings = () => {
11901194
const setAiEnabled = useStore((state) => state.setAiEnabled);
11911195
const setAiShareContext = useStore((state) => state.setAiShareContext);
11921196

1197+
return (
1198+
<>
1199+
<Card shadow="sm">
1200+
<CardBody className="flex flex-col gap-4 mb-4">
1201+
<h2 className="text-xl font-semibold">AI</h2>
1202+
<p className="text-sm text-default-500">
1203+
Configure AI-powered features in runbooks
1204+
</p>
1205+
1206+
<SettingSwitch
1207+
label="Enable AI features"
1208+
isSelected={aiEnabled}
1209+
onValueChange={setAiEnabled}
1210+
description="Enable AI block generation and editing (Cmd+Enter, Cmd+K, AI Agent Sidebar)"
1211+
/>
1212+
1213+
{aiEnabled && (
1214+
<SettingSwitch
1215+
className="ml-4"
1216+
label="Share document context"
1217+
isSelected={aiShareContext}
1218+
onValueChange={setAiShareContext}
1219+
description="Send document content to improve AI suggestions. Disable for sensitive documents."
1220+
/>
1221+
)}
1222+
</CardBody>
1223+
</Card>
1224+
{aiEnabled && (
1225+
<>
1226+
<AgentSettings />
1227+
<AIOllamaSettings />
1228+
</>
1229+
)}
1230+
</>
1231+
);
1232+
};
1233+
1234+
const AgentSettings = () => {
1235+
const providers = [
1236+
["Atuin Hub", "atuinhub"],
1237+
["Ollama", "ollama"]
1238+
]
1239+
1240+
const [aiProvider, setAiProvider, aiProviderLoading] = useSettingsState(
1241+
"ai_provider",
1242+
"atuinhub",
1243+
Settings.aiAgentProvider,
1244+
Settings.aiAgentProvider,
1245+
);
1246+
1247+
const handleProviderChange = (keys: SharedSelection) => {
1248+
const key = keys.currentKey as string;
1249+
if (key) {
1250+
setAiProvider(key);
1251+
}
1252+
};
1253+
11931254
return (
11941255
<Card shadow="sm">
11951256
<CardBody className="flex flex-col gap-4">
1196-
<h2 className="text-xl font-semibold">AI</h2>
1197-
<p className="text-sm text-default-500">
1198-
Configure AI-powered features in runbooks
1199-
</p>
1257+
<h2 className="text-xl font-semibold">AI Agent</h2>
1258+
1259+
<Select
1260+
label="Default AI provider"
1261+
value={aiProvider}
1262+
onSelectionChange={handleProviderChange}
1263+
className="mt-4"
1264+
placeholder="Select default AI provider"
1265+
selectedKeys={[aiProvider]}
1266+
items={providers.map(([name, id]) => ({ label: name, key: id }))}
1267+
isDisabled={aiProviderLoading}
1268+
>
1269+
{(item) => <SelectItem key={item.key}>{item.label}</SelectItem>}
1270+
</Select>
1271+
</CardBody>
1272+
</Card>
1273+
);
1274+
};
1275+
1276+
const AIOllamaSettings = () => {
1277+
const [ollamaSettings, setOllamaSettings, isLoading] = useAIProviderSettings<OllamaSettings>("ollama", {
1278+
enabled: false,
1279+
endpoint: "http://localhost:11434",
1280+
model: "",
1281+
});
1282+
1283+
return (
1284+
<Card shadow="sm">
1285+
<CardBody className="flex flex-col gap-4">
1286+
<h2 className="text-xl font-semibold">Ollama</h2>
12001287

12011288
<SettingSwitch
1202-
label="Enable AI features"
1203-
isSelected={aiEnabled}
1204-
onValueChange={setAiEnabled}
1205-
description="Enable AI block generation and editing (Cmd+Enter, Cmd+K)"
1289+
label="Enable Ollama AI provider"
1290+
isSelected={ollamaSettings.enabled}
1291+
onValueChange={(enabled) => setOllamaSettings({ ...ollamaSettings, enabled })}
1292+
description="Toggle to use Ollama as the AI provider."
12061293
/>
12071294

1208-
{aiEnabled && (
1209-
<SettingSwitch
1210-
className="ml-4"
1211-
label="Share document context"
1212-
isSelected={aiShareContext}
1213-
onValueChange={setAiShareContext}
1214-
description="Send document content to improve AI suggestions. Disable for sensitive documents."
1215-
/>
1295+
{ollamaSettings.enabled && (
1296+
<div className="flex flex-col gap-4">
1297+
<Input
1298+
label="Endpoint (optional, defaults to http://localhost:11434)"
1299+
placeholder="Endpoint URL (e.g. http://localhost:11434)"
1300+
value={ollamaSettings.endpoint}
1301+
onValueChange={(value) => setOllamaSettings({ ...ollamaSettings, endpoint: value })}
1302+
isDisabled={isLoading}
1303+
/>
1304+
1305+
<Input
1306+
label="Model (required; your chosen model must support tool calling)"
1307+
placeholder="Model name"
1308+
value={ollamaSettings.model}
1309+
onValueChange={(value) => setOllamaSettings({ ...ollamaSettings, model: value })}
1310+
isDisabled={isLoading}
1311+
/>
1312+
</div>
12161313
)}
12171314
</CardBody>
12181315
</Card>
@@ -1235,7 +1332,7 @@ const UserSettings = () => {
12351332
const deepLink = `atuin://register-token/${token}`;
12361333
// token submit deep link doesn't require a runbook activation,
12371334
// so passing an empty function for simplicity
1238-
handleDeepLink(deepLink, () => {});
1335+
handleDeepLink(deepLink, () => { });
12391336
}
12401337

12411338
let content;

src/components/runbooks/editor/ui/AIAssistant.tsx

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ import { Settings } from "@/state/settings";
4646
import { useStore } from "@/state/store";
4747
import { ChargeTarget } from "@/rs-bindings/ChargeTarget";
4848
import AtuinEnv from "@/atuin_env";
49+
import { getModelSelection } from "@/state/settings_ai";
50+
import { DialogBuilder } from "@/components/Dialogs/dialog";
4951

5052
const ALL_TOOL_NAMES = [
5153
"get_runbook_document",
@@ -628,11 +630,27 @@ export default function AIAssistant({
628630
}
629631
}, [isOpen, sessionId]);
630632

631-
const handleSend = useCallback(() => {
633+
const handleSend = useCallback(async () => {
632634
if (!inputValue.trim() || isStreaming || !sessionId) return;
633-
// TODO: Allow buffering one message while streaming
634-
sendMessage(inputValue.trim());
635+
636+
const input = inputValue.trim();
635637
setInputValue("");
638+
639+
const aiProvider = await Settings.aiAgentProvider();
640+
const modelSelection = await getModelSelection(aiProvider);
641+
if (modelSelection.isErr()) {
642+
const err = modelSelection.unwrapErr();
643+
await new DialogBuilder()
644+
.title("AI Provider Error")
645+
.icon("error")
646+
.message("There was an error setting up your selected AI provider: " + err)
647+
.action({ label: "OK", value: undefined, variant: "flat" })
648+
.build();
649+
return;
650+
}
651+
652+
// TODO: Allow buffering one message while streaming
653+
sendMessage(input, modelSelection.unwrap());
636654
}, [inputValue, isStreaming, sessionId, sendMessage]);
637655

638656
const handleKeyDown = (e: React.KeyboardEvent) => {

0 commit comments

Comments (0)