diff --git a/Cargo.lock b/Cargo.lock index 6df38e98..884ae4b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3213,7 +3213,7 @@ dependencies = [ [[package]] name = "genai" version = "0.5.0-alpha.10-WIP" -source = "git+https://github.com/BinaryMuse/rust-genai?rev=674535905a966b44104a327dd1e2ca80f4b4a444#674535905a966b44104a327dd1e2ca80f4b4a444" +source = "git+https://github.com/BinaryMuse/rust-genai?rev=ce7feec4ae112cc9ad442841c6b49961a599580e#ce7feec4ae112cc9ad442841c6b49961a599580e" dependencies = [ "base64 0.22.1", "bytes", diff --git a/backend/Cargo.toml b/backend/Cargo.toml index a70ef8e0..583bbd83 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -88,7 +88,7 @@ log = { workspace = true } dirs = { workspace = true } tempfile = { workspace = true } config = "0.15.19" -genai = { git = "https://github.com/BinaryMuse/rust-genai", rev = "674535905a966b44104a327dd1e2ca80f4b4a444" } +genai = { git = "https://github.com/BinaryMuse/rust-genai", rev = "ce7feec4ae112cc9ad442841c6b49961a599580e" } futures-util = "0.3.31" indoc = "2.0.7" diff --git a/backend/src/ai/fsm.rs b/backend/src/ai/fsm.rs index 918b7801..7e2771f5 100644 --- a/backend/src/ai/fsm.rs +++ b/backend/src/ai/fsm.rs @@ -257,23 +257,18 @@ impl Agent { } /// Push accumulated tool results to conversation as tool response messages. + /// Each tool response becomes a separate message with ChatRole::Tool. fn push_tool_results_to_conversation(&mut self) { - if !self.context.tool_results.is_empty() { - let tool_result_parts = self - .context - .tool_results - .drain(..) - .map(|result| { - let result_str = match result.output { - ToolOutput::Success(s) => s, - ToolOutput::Error(e) => format!("Error: {}", e), - }; - ContentPart::ToolResponse(ToolResponse::new(result.call_id, result_str)) - }) - .collect::>(); - let tool_result_content = MessageContent::from_parts(tool_result_parts); - let tool_result_message = ChatMessage::user(tool_result_content); - self.context.conversation.push(tool_result_message); + for result in self.context.tool_results.drain(..) { + let result_str = match result.output { + ToolOutput::Success(s) => s, + ToolOutput::Error(e) => format!("Error: {}", e), + }; + let tool_response = ToolResponse::new(result.call_id, result_str); + // ChatMessage::from(ToolResponse) creates a message with ChatRole::Tool + self.context + .conversation + .push(ChatMessage::from(tool_response)); } } @@ -824,7 +819,7 @@ mod tests { assert_eq!(t.effects, vec![Effect::Cancelled]); assert!(agent.context().pending_tools.is_empty()); // Tool results were pushed to conversation as error responses - // Conversation: user msg, assistant msg, tool response (with 2 cancelled results) - assert_eq!(agent.context().conversation.len(), 3); + // Conversation: user msg, assistant msg, tool response, tool response + assert_eq!(agent.context().conversation.len(), 4); } } diff --git a/backend/src/ai/session.rs b/backend/src/ai/session.rs index f1308cd4..2d3cbc35 100644 --- a/backend/src/ai/session.rs +++ b/backend/src/ai/session.rs @@ -117,7 +117,17 @@ impl SessionHandle { } /// Send a user message to the session. - pub async fn send_user_message(&self, content: String) -> Result<(), AISessionError> { + pub async fn send_user_message( + &self, + content: String, + model: ModelSelection, + ) -> Result<(), AISessionError> { + let msg = Event::ModelChange(model); + self.event_tx + .send(msg) + .await + .map_err(|_| AISessionError::ChannelClosed)?; + let msg = ChatMessage::user(content); self.event_tx .send(Event::UserMessage(msg)) @@ -204,12 +214,16 @@ fn resolve_service_target( } }; - let auth = AuthData::Key( - AISession::get_api_key(adapter_kind, parts[0] == "atuinhub") - .await - .map_err(|e| genai::resolver::Error::Custom(e.to_string()))?, - ); - service_target.auth = auth; + let key = AISession::get_api_key(adapter_kind, parts[0] == "atuinhub") + .await + .map_err(|e| genai::resolver::Error::Custom(e.to_string()))?; + + if let Some(key) = key { + let auth = AuthData::Key(key); + service_target.auth = auth; + } else { + service_target.auth = AuthData::Key("".to_string()); + } let model_id = ModelIden::new(adapter_kind, parts[1]); service_target.model = model_id; @@ -354,12 +368,12 @@ impl AISession { async fn get_api_key( _adapter_kind: AdapterKind, is_hub: bool, - ) -> Result { + ) -> Result, AISessionError> { if is_hub { - return Ok("".to_string()); + return Ok(None); } - Ok("".to_string()) + Ok(None) } /// Run the session event loop. @@ -598,14 +612,10 @@ impl AISession { Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => { log::trace!("Session {} received thought signature chunk", session_id,); } - Ok(ChatStreamEvent::ToolCallChunk(tc_chunk)) => { + Ok(ChatStreamEvent::ToolCallChunk(_tc_chunk)) => { // Tool call chunks are accumulated by genai internally // We'll get the complete tool calls in the End event - log::trace!( - "Session {} received tool call chunk: {:?}", - session_id, - tc_chunk - ); + log::trace!("Session {} received tool call chunk", session_id); } Ok(ChatStreamEvent::ReasoningChunk(_)) => { log::trace!("Session {} received reasoning chunk", session_id); diff --git a/backend/src/ai/types.rs b/backend/src/ai/types.rs index ccbc0621..48d8a926 100644 --- a/backend/src/ai/types.rs +++ b/backend/src/ai/types.rs @@ -28,7 +28,7 @@ impl fmt::Display for ModelSelection { }, ModelSelection::Ollama { model, uri } => match uri { Some(uri) => write!(f, "ollama::{model}::{}", uri.deref()), - None => write!(f, "ollama::{model}::default"), + None => write!(f, "ollama::{model}::http://localhost:11434"), }, } } diff --git a/backend/src/commands/ai.rs b/backend/src/commands/ai.rs index 988fa5d3..4d265753 100644 --- a/backend/src/commands/ai.rs +++ b/backend/src/commands/ai.rs @@ -300,6 +300,7 @@ pub async fn ai_send_message( state: tauri::State<'_, AtuinState>, session_id: Uuid, message: String, + model: ModelSelection, ) -> Result<(), String> { let sessions = state.ai_sessions.read().await; let handle = sessions @@ -307,7 +308,7 @@ pub async fn ai_send_message( .ok_or_else(|| format!("Session {} not found", session_id))?; handle - .send_user_message(message) + .send_user_message(message, model) .await .map_err(|e| e.to_string()) } diff --git a/src/components/Settings/Settings.tsx b/src/components/Settings/Settings.tsx index b3ab365e..13fc05d8 100644 --- a/src/components/Settings/Settings.tsx +++ b/src/components/Settings/Settings.tsx @@ -39,6 +39,7 @@ import handleDeepLink from "@/routes/root/deep"; import * as api from "@/api/api"; import InterpreterSelector from "@/lib/blocks/common/InterpreterSelector"; import AtuinEnv from "@/atuin_env"; +import { OllamaSettings, useAIProviderSettings } from "@/state/settings_ai"; async function loadFonts(): Promise { const fonts = await invoke("list_fonts"); @@ -49,7 +50,7 @@ async function loadFonts(): Promise { } // Custom hook for managing settings -const useSettingsState = ( +export const useSettingsState = ( _key: any, initialValue: any, settingsGetter: any, @@ -109,6 +110,7 @@ interface SettingsSwitchProps { onValueChange: (e: boolean) => void; description: string; className?: string; + isDisabled?: boolean; } const SettingSwitch = ({ @@ -117,11 +119,13 @@ const SettingSwitch = ({ onValueChange, description, className, + isDisabled, }: SettingsSwitchProps) => (
{label} @@ -1190,29 +1194,122 @@ const AISettings = () => { const setAiEnabled = useStore((state) => state.setAiEnabled); const setAiShareContext = useStore((state) => state.setAiShareContext); + return ( + <> + + +

AI

+

+ Configure AI-powered features in runbooks +

+ + + + {aiEnabled && ( + + )} +
+
+ {aiEnabled && ( + <> + + + + )} + + ); +}; + +const AgentSettings = () => { + const providers = [ + ["Atuin Hub", "atuinhub"], + ["Ollama", "ollama"] + ] + + const [aiProvider, setAiProvider, aiProviderLoading] = useSettingsState( + "ai_provider", + "atuinhub", + Settings.aiAgentProvider, + Settings.aiAgentProvider, + ); + + const handleProviderChange = (keys: SharedSelection) => { + const key = keys.currentKey as string; + if (key) { + setAiProvider(key); + } + }; + return ( -

AI

-

- Configure AI-powered features in runbooks -

+

AI Agent

+ + +
+
+ ); +}; + +const AIOllamaSettings = () => { + const [ollamaSettings, setOllamaSettings, isLoading] = useAIProviderSettings("ollama", { + enabled: false, + endpoint: "http://localhost:11434", + model: "", + }); + + return ( + + +

Ollama

setOllamaSettings({ ...ollamaSettings, enabled })} + description="Toggle to use Ollama as the AI provider." /> - {aiEnabled && ( - + {ollamaSettings.enabled && ( +
+ setOllamaSettings({ ...ollamaSettings, endpoint: value })} + isDisabled={isLoading} + /> + + setOllamaSettings({ ...ollamaSettings, model: value })} + isDisabled={isLoading} + /> +
)}
@@ -1235,7 +1332,7 @@ const UserSettings = () => { const deepLink = `atuin://register-token/${token}`; // token submit deep link doesn't require a runbook activation, // so passing an empty function for simplicity - handleDeepLink(deepLink, () => {}); + handleDeepLink(deepLink, () => { }); } let content; diff --git a/src/components/runbooks/editor/ui/AIAssistant.tsx b/src/components/runbooks/editor/ui/AIAssistant.tsx index c6396435..bc0090e0 100644 --- a/src/components/runbooks/editor/ui/AIAssistant.tsx +++ b/src/components/runbooks/editor/ui/AIAssistant.tsx @@ -46,6 +46,8 @@ import { Settings } from "@/state/settings"; import { useStore } from "@/state/store"; import { ChargeTarget } from "@/rs-bindings/ChargeTarget"; import AtuinEnv from "@/atuin_env"; +import { getModelSelection } from "@/state/settings_ai"; +import { DialogBuilder } from "@/components/Dialogs/dialog"; const ALL_TOOL_NAMES = [ "get_runbook_document", @@ -628,11 +630,27 @@ export default function AIAssistant({ } }, [isOpen, sessionId]); - const handleSend = useCallback(() => { + const handleSend = useCallback(async () => { if (!inputValue.trim() || isStreaming || !sessionId) return; - // TODO: Allow buffering one message while streaming - sendMessage(inputValue.trim()); + + const input = inputValue.trim(); setInputValue(""); + + const aiProvider = await Settings.aiAgentProvider(); + const modelSelection = await getModelSelection(aiProvider); + if (modelSelection.isErr()) { + const err = modelSelection.unwrapErr(); + await new DialogBuilder() + .title("AI Provider Error") + .icon("error") + .message("There was an error setting up your selected AI provider: " + err) + .action({ label: "OK", value: undefined, variant: "flat" }) + .build(); + return; + } + + // TODO: Allow buffering one message while streaming + sendMessage(input, modelSelection.unwrap()); }, [inputValue, isStreaming, sessionId, sendMessage]); const handleKeyDown = (e: React.KeyboardEvent) => { diff --git a/src/lib/ai/commands.ts b/src/lib/ai/commands.ts index 83ef10fc..da6afd9f 100644 --- a/src/lib/ai/commands.ts +++ b/src/lib/ai/commands.ts @@ -70,8 +70,8 @@ export async function changeUser(sessionId: string, user: string): Promise /** * Send a user message to an AI session. */ -export async function sendMessage(sessionId: string, message: string): Promise { - await invoke("ai_send_message", { sessionId, message }); +export async function sendMessage(sessionId: string, message: string, model?: ModelSelection): Promise { + await invoke("ai_send_message", { sessionId, message, model }); } /** diff --git a/src/lib/ai/useAIChat.ts b/src/lib/ai/useAIChat.ts index 76aa7bce..26ab2fe9 100644 --- a/src/lib/ai/useAIChat.ts +++ b/src/lib/ai/useAIChat.ts @@ -9,6 +9,7 @@ import { } from "./commands"; import { useCallback, useEffect, useMemo, useState } from "react"; import { State } from "@/rs-bindings/State"; +import { ModelSelection } from "@/rs-bindings/ModelSelection"; export interface AIChatAPI { sessionId: string; @@ -19,7 +20,7 @@ export interface AIChatAPI { pendingToolCalls: Array; error: string | null; state: State["type"]; - sendMessage: (message: string) => void; + sendMessage: (message: string, model?: ModelSelection) => void; addToolOutput: (output: AIToolOutput) => void; cancel: () => void; } @@ -153,7 +154,7 @@ export default function useAIChat(sessionId: string): AIChatAPI { }, [sessionId]); const sendMessage = useCallback( - async (message: string) => { + async (message: string, model?: ModelSelection) => { const userMessage: AIMessage = { role: "user", content: { parts: [{ type: "text", data: message }] }, @@ -168,7 +169,7 @@ export default function useAIChat(sessionId: string): AIChatAPI { } setError(null); - await sendMessageCommand(sessionId, message); + await sendMessageCommand(sessionId, message, model); }, [sessionId, isIdle], ); diff --git a/src/state/settings.ts b/src/state/settings.ts index 76c09470..b9cdfc78 100644 --- a/src/state/settings.ts +++ b/src/state/settings.ts @@ -40,6 +40,8 @@ const NOTIFICATIONS_SERIAL_PAUSED_DURATION = "settings.notifications.serial.paus const NOTIFICATIONS_SERIAL_PAUSED_SOUND = "settings.notifications.serial.paused.sound"; const NOTIFICATIONS_SERIAL_PAUSED_OS = "settings.notifications.serial.paused.os"; +const AI_AGENT_PROVIDER = "settings.ai.agent.provider"; + export class Settings { public static DEFAULT_FONT = "FiraCode"; public static DEFAULT_FONT_SIZE = 14; @@ -407,4 +409,26 @@ export class Settings { } return await this.getSystemDefaultShell(); } + + public static async aiAgentProvider(val: string | null = null): Promise { + let store = await KVStore.open_default(); + + if (val !== null) { + await store.set(AI_AGENT_PROVIDER, val); + return val; + } + + return await store.get(AI_AGENT_PROVIDER) ?? "atuinhub"; + } + + public static async aiProviderSettings>(provider: string, settings: T | null = null): Promise { + let store = await KVStore.open_default(); + + if (settings !== null) { + await store.set(`ai.provider.${provider}.settings`, settings); + return settings; + } + + return await store.get(`ai.provider.${provider}.settings`) ?? {} as T; + } } diff --git a/src/state/settings_ai.ts b/src/state/settings_ai.ts new file mode 100644 index 00000000..61f63665 --- /dev/null +++ b/src/state/settings_ai.ts @@ -0,0 +1,65 @@ +import AtuinEnv from "@/atuin_env"; +import { useSettingsState } from "@/components/Settings/Settings"; +import { ModelSelection } from "@/rs-bindings/ModelSelection"; +import { Settings } from "@/state/settings"; + +export interface OllamaSettings { + enabled: boolean; + endpoint: string; + model: string; +} + +export function useAIProviderSettings>(provider: string, defaultValue: T): [T, (settings: T) => void, boolean] { + const [settings, setSettings, isLoading] = useSettingsState( + `ai.provider.${provider}.settings`, + defaultValue as T, + () => Settings.aiProviderSettings(provider), + (settings: T) => Settings.aiProviderSettings(provider, settings), + ); + return [settings, setSettings, isLoading]; +}; + +export async function getAIProviderSettings>(provider: string): Promise { + const value = await Settings.aiProviderSettings(provider); + return value as T; +} + +export async function getModelSelection(provider: string): Promise> { + if (provider === "atuinhub") { + return Ok({ + type: "atuinHub", + data: { + model: "claude-opus-4-5-20251101", + uri: AtuinEnv.url("/api/ai/proxy/"), + } + }) as Result + } else if (provider === "ollama") { + const settings = await getAIProviderSettings("ollama"); + if (!settings.enabled) { + return Err("Ollama is not enabled in settings"); + } + if (!settings.model) { + return Err("Ollama model is not set in settings"); + } + + return Ok({ + type: "ollama", + data: { + model: settings.model, + uri: joinUrlParts([settings.endpoint, "v1/"]), + } + }) as Result + } else { + return Ok({ + type: "atuinHub", + data: { + model: "claude-opus-4-5-20251101", + uri: AtuinEnv.url("/api/ai/proxy/"), + } + }) as Result + } +} + +function joinUrlParts(parts: string[]): string { + return parts.map(p => p.replace(/\/+$/, '')).join('/').replace(/([^:]\/)\/+/g, '$1'); +}