Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ log = { workspace = true }
dirs = { workspace = true }
tempfile = { workspace = true }
config = "0.15.19"
genai = { git = "https://github.com/BinaryMuse/rust-genai", rev = "674535905a966b44104a327dd1e2ca80f4b4a444" }
genai = { git = "https://github.com/BinaryMuse/rust-genai", rev = "ce7feec4ae112cc9ad442841c6b49961a599580e" }
futures-util = "0.3.31"
indoc = "2.0.7"

Expand Down
31 changes: 13 additions & 18 deletions backend/src/ai/fsm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,23 +257,18 @@ impl Agent {
}

/// Push accumulated tool results to conversation as tool response messages.
/// Each tool response becomes a separate message with ChatRole::Tool.
fn push_tool_results_to_conversation(&mut self) {
if !self.context.tool_results.is_empty() {
let tool_result_parts = self
.context
.tool_results
.drain(..)
.map(|result| {
let result_str = match result.output {
ToolOutput::Success(s) => s,
ToolOutput::Error(e) => format!("Error: {}", e),
};
ContentPart::ToolResponse(ToolResponse::new(result.call_id, result_str))
})
.collect::<Vec<ContentPart>>();
let tool_result_content = MessageContent::from_parts(tool_result_parts);
let tool_result_message = ChatMessage::user(tool_result_content);
self.context.conversation.push(tool_result_message);
for result in self.context.tool_results.drain(..) {
let result_str = match result.output {
ToolOutput::Success(s) => s,
ToolOutput::Error(e) => format!("Error: {}", e),
};
let tool_response = ToolResponse::new(result.call_id, result_str);
// ChatMessage::from(ToolResponse) creates a message with ChatRole::Tool
self.context
.conversation
.push(ChatMessage::from(tool_response));
}
}

Expand Down Expand Up @@ -824,7 +819,7 @@ mod tests {
assert_eq!(t.effects, vec![Effect::Cancelled]);
assert!(agent.context().pending_tools.is_empty());
// Tool results were pushed to conversation as error responses
// Conversation: user msg, assistant msg, tool response (with 2 cancelled results)
assert_eq!(agent.context().conversation.len(), 3);
// Conversation: user msg, assistant msg, tool response, tool response
assert_eq!(agent.context().conversation.len(), 4);
}
}
42 changes: 26 additions & 16 deletions backend/src/ai/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,17 @@ impl SessionHandle {
}

/// Send a user message to the session.
pub async fn send_user_message(&self, content: String) -> Result<(), AISessionError> {
pub async fn send_user_message(
&self,
content: String,
model: ModelSelection,
) -> Result<(), AISessionError> {
let msg = Event::ModelChange(model);
self.event_tx
.send(msg)
.await
.map_err(|_| AISessionError::ChannelClosed)?;

let msg = ChatMessage::user(content);
self.event_tx
.send(Event::UserMessage(msg))
Expand Down Expand Up @@ -204,12 +214,16 @@ fn resolve_service_target(
}
};

let auth = AuthData::Key(
AISession::get_api_key(adapter_kind, parts[0] == "atuinhub")
.await
.map_err(|e| genai::resolver::Error::Custom(e.to_string()))?,
);
service_target.auth = auth;
let key = AISession::get_api_key(adapter_kind, parts[0] == "atuinhub")
.await
.map_err(|e| genai::resolver::Error::Custom(e.to_string()))?;

if let Some(key) = key {
let auth = AuthData::Key(key);
service_target.auth = auth;
} else {
service_target.auth = AuthData::Key("".to_string());
}

let model_id = ModelIden::new(adapter_kind, parts[1]);
service_target.model = model_id;
Expand Down Expand Up @@ -354,12 +368,12 @@ impl AISession {
async fn get_api_key(
_adapter_kind: AdapterKind,
is_hub: bool,
) -> Result<String, AISessionError> {
) -> Result<Option<String>, AISessionError> {
if is_hub {
return Ok("".to_string());
return Ok(None);
}

Ok("".to_string())
Ok(None)
}

/// Run the session event loop.
Expand Down Expand Up @@ -598,14 +612,10 @@ impl AISession {
Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
log::trace!("Session {} received thought signature chunk", session_id,);
}
Ok(ChatStreamEvent::ToolCallChunk(tc_chunk)) => {
Ok(ChatStreamEvent::ToolCallChunk(_tc_chunk)) => {
// Tool call chunks are accumulated by genai internally
// We'll get the complete tool calls in the End event
log::trace!(
"Session {} received tool call chunk: {:?}",
session_id,
tc_chunk
);
log::trace!("Session {} received tool call chunk", session_id);
}
Ok(ChatStreamEvent::ReasoningChunk(_)) => {
log::trace!("Session {} received reasoning chunk", session_id);
Expand Down
2 changes: 1 addition & 1 deletion backend/src/ai/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl fmt::Display for ModelSelection {
},
ModelSelection::Ollama { model, uri } => match uri {
Some(uri) => write!(f, "ollama::{model}::{}", uri.deref()),
None => write!(f, "ollama::{model}::default"),
None => write!(f, "ollama::{model}::http://localhost:11434"),
},
}
}
Expand Down
3 changes: 2 additions & 1 deletion backend/src/commands/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,15 @@ pub async fn ai_send_message(
state: tauri::State<'_, AtuinState>,
session_id: Uuid,
message: String,
model: ModelSelection,
) -> Result<(), String> {
let sessions = state.ai_sessions.read().await;
let handle = sessions
.get(&session_id)
.ok_or_else(|| format!("Session {} not found", session_id))?;

handle
.send_user_message(message)
.send_user_message(message, model)
.await
.map_err(|e| e.to_string())
}
Expand Down
133 changes: 115 additions & 18 deletions src/components/Settings/Settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import handleDeepLink from "@/routes/root/deep";
import * as api from "@/api/api";
import InterpreterSelector from "@/lib/blocks/common/InterpreterSelector";
import AtuinEnv from "@/atuin_env";
import { OllamaSettings, useAIProviderSettings } from "@/state/settings_ai";

async function loadFonts(): Promise<string[]> {
const fonts = await invoke<string[]>("list_fonts");
Expand All @@ -49,7 +50,7 @@ async function loadFonts(): Promise<string[]> {
}

// Custom hook for managing settings
const useSettingsState = (
export const useSettingsState = (
_key: any,
initialValue: any,
settingsGetter: any,
Expand Down Expand Up @@ -109,6 +110,7 @@ interface SettingsSwitchProps {
onValueChange: (e: boolean) => void;
description: string;
className?: string;
isDisabled?: boolean;
}

const SettingSwitch = ({
Expand All @@ -117,11 +119,13 @@ const SettingSwitch = ({
onValueChange,
description,
className,
isDisabled,
}: SettingsSwitchProps) => (
<Switch
isSelected={isSelected}
onValueChange={onValueChange}
className={cn("flex justify-between items-center w-full", className)}
isDisabled={isDisabled || false}
>
<div className="flex flex-col">
<span>{label}</span>
Expand Down Expand Up @@ -1190,29 +1194,122 @@ const AISettings = () => {
const setAiEnabled = useStore((state) => state.setAiEnabled);
const setAiShareContext = useStore((state) => state.setAiShareContext);

return (
<>
<Card shadow="sm">
<CardBody className="flex flex-col gap-4 mb-4">
<h2 className="text-xl font-semibold">AI</h2>
<p className="text-sm text-default-500">
Configure AI-powered features in runbooks
</p>

<SettingSwitch
label="Enable AI features"
isSelected={aiEnabled}
onValueChange={setAiEnabled}
description="Enable AI block generation and editing (Cmd+Enter, Cmd+K, AI Agent Sidebar)"
/>

{aiEnabled && (
<SettingSwitch
className="ml-4"
label="Share document context"
isSelected={aiShareContext}
onValueChange={setAiShareContext}
description="Send document content to improve AI suggestions. Disable for sensitive documents."
/>
)}
</CardBody>
</Card>
{aiEnabled && (
<>
<AgentSettings />
<AIOllamaSettings />
</>
)}
</>
);
};

const AgentSettings = () => {
const providers = [
["Atuin Hub", "atuinhub"],
["Ollama", "ollama"]
]

const [aiProvider, setAiProvider, aiProviderLoading] = useSettingsState(
"ai_provider",
"atuinhub",
Settings.aiAgentProvider,
Settings.aiAgentProvider,
);

const handleProviderChange = (keys: SharedSelection) => {
const key = keys.currentKey as string;
if (key) {
setAiProvider(key);
}
};

return (
<Card shadow="sm">
<CardBody className="flex flex-col gap-4">
<h2 className="text-xl font-semibold">AI</h2>
<p className="text-sm text-default-500">
Configure AI-powered features in runbooks
</p>
<h2 className="text-xl font-semibold">AI Agent</h2>

<Select
label="Default AI provider"
value={aiProvider}
onSelectionChange={handleProviderChange}
className="mt-4"
placeholder="Select default AI provider"
selectedKeys={[aiProvider]}
items={providers.map(([name, id]) => ({ label: name, key: id }))}
isDisabled={aiProviderLoading}
>
{(item) => <SelectItem key={item.key}>{item.label}</SelectItem>}
</Select>
</CardBody>
</Card>
);
};

const AIOllamaSettings = () => {
const [ollamaSettings, setOllamaSettings, isLoading] = useAIProviderSettings<OllamaSettings>("ollama", {
enabled: false,
endpoint: "http://localhost:11434",
model: "",
});

return (
<Card shadow="sm">
<CardBody className="flex flex-col gap-4">
<h2 className="text-xl font-semibold">Ollama</h2>

<SettingSwitch
label="Enable AI features"
isSelected={aiEnabled}
onValueChange={setAiEnabled}
description="Enable AI block generation and editing (Cmd+Enter, Cmd+K)"
label="Enable Ollama AI provider"
isSelected={ollamaSettings.enabled}
onValueChange={(enabled) => setOllamaSettings({ ...ollamaSettings, enabled })}
description="Toggle to use Ollama as the AI provider."
/>

{aiEnabled && (
<SettingSwitch
className="ml-4"
label="Share document context"
isSelected={aiShareContext}
onValueChange={setAiShareContext}
description="Send document content to improve AI suggestions. Disable for sensitive documents."
/>
{ollamaSettings.enabled && (
<div className="flex flex-col gap-4">
<Input
label="Endpoint (optional, defaults to http://localhost:11434)"
placeholder="Endpoint URL (e.g. http://localhost:11434)"
value={ollamaSettings.endpoint}
onValueChange={(value) => setOllamaSettings({ ...ollamaSettings, endpoint: value })}
isDisabled={isLoading}
/>

<Input
label="Model (required; your chosen model must support tool calling)"
placeholder="Model name"
value={ollamaSettings.model}
onValueChange={(value) => setOllamaSettings({ ...ollamaSettings, model: value })}
isDisabled={isLoading}
/>
</div>
)}
</CardBody>
</Card>
Expand All @@ -1235,7 +1332,7 @@ const UserSettings = () => {
const deepLink = `atuin://register-token/${token}`;
// token submit deep link doesn't require a runbook activation,
// so passing an empty function for simplicity
handleDeepLink(deepLink, () => {});
handleDeepLink(deepLink, () => { });
}

let content;
Expand Down
24 changes: 21 additions & 3 deletions src/components/runbooks/editor/ui/AIAssistant.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ import { Settings } from "@/state/settings";
import { useStore } from "@/state/store";
import { ChargeTarget } from "@/rs-bindings/ChargeTarget";
import AtuinEnv from "@/atuin_env";
import { getModelSelection } from "@/state/settings_ai";
import { DialogBuilder } from "@/components/Dialogs/dialog";

const ALL_TOOL_NAMES = [
"get_runbook_document",
Expand Down Expand Up @@ -628,11 +630,27 @@ export default function AIAssistant({
}
}, [isOpen, sessionId]);

const handleSend = useCallback(() => {
const handleSend = useCallback(async () => {
if (!inputValue.trim() || isStreaming || !sessionId) return;
// TODO: Allow buffering one message while streaming
sendMessage(inputValue.trim());

const input = inputValue.trim();
setInputValue("");

const aiProvider = await Settings.aiAgentProvider();
const modelSelection = await getModelSelection(aiProvider);
if (modelSelection.isErr()) {
const err = modelSelection.unwrapErr();
await new DialogBuilder()
.title("AI Provider Error")
.icon("error")
.message("There was an error setting up your selected AI provider: " + err)
.action({ label: "OK", value: undefined, variant: "flat" })
.build();
return;
}

// TODO: Allow buffering one message while streaming
sendMessage(input, modelSelection.unwrap());
}, [inputValue, isStreaming, sessionId, sendMessage]);

const handleKeyDown = (e: React.KeyboardEvent) => {
Expand Down
Loading
Loading