Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ log = { workspace = true }
dirs = { workspace = true }
tempfile = { workspace = true }
config = "0.15.19"
genai = { git = "https://github.com/BinaryMuse/rust-genai", rev = "674535905a966b44104a327dd1e2ca80f4b4a444" }
genai = { git = "https://github.com/BinaryMuse/rust-genai", rev = "ce7feec4ae112cc9ad442841c6b49961a599580e" }
futures-util = "0.3.31"
indoc = "2.0.7"

Expand Down
31 changes: 13 additions & 18 deletions backend/src/ai/fsm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,23 +257,18 @@ impl Agent {
}

/// Push accumulated tool results to conversation as tool response messages.
/// Each tool response becomes a separate message with ChatRole::Tool.
fn push_tool_results_to_conversation(&mut self) {
if !self.context.tool_results.is_empty() {
let tool_result_parts = self
.context
.tool_results
.drain(..)
.map(|result| {
let result_str = match result.output {
ToolOutput::Success(s) => s,
ToolOutput::Error(e) => format!("Error: {}", e),
};
ContentPart::ToolResponse(ToolResponse::new(result.call_id, result_str))
})
.collect::<Vec<ContentPart>>();
let tool_result_content = MessageContent::from_parts(tool_result_parts);
let tool_result_message = ChatMessage::user(tool_result_content);
self.context.conversation.push(tool_result_message);
for result in self.context.tool_results.drain(..) {
let result_str = match result.output {
ToolOutput::Success(s) => s,
ToolOutput::Error(e) => format!("Error: {}", e),
};
let tool_response = ToolResponse::new(result.call_id, result_str);
// ChatMessage::from(ToolResponse) creates a message with ChatRole::Tool
self.context
.conversation
.push(ChatMessage::from(tool_response));
}
}

Expand Down Expand Up @@ -824,7 +819,7 @@ mod tests {
assert_eq!(t.effects, vec![Effect::Cancelled]);
assert!(agent.context().pending_tools.is_empty());
// Tool results were pushed to conversation as error responses
// Conversation: user msg, assistant msg, tool response (with 2 cancelled results)
assert_eq!(agent.context().conversation.len(), 3);
// Conversation: user msg, assistant msg, tool response, tool response
assert_eq!(agent.context().conversation.len(), 4);
}
}
41 changes: 25 additions & 16 deletions backend/src/ai/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,17 @@ impl SessionHandle {
}

/// Send a user message to the session.
pub async fn send_user_message(&self, content: String) -> Result<(), AISessionError> {
pub async fn send_user_message(
&self,
content: String,
model: ModelSelection,
) -> Result<(), AISessionError> {
let msg = Event::ModelChange(model);
self.event_tx
.send(msg)
.await
.map_err(|_| AISessionError::ChannelClosed)?;

let msg = ChatMessage::user(content);
self.event_tx
.send(Event::UserMessage(msg))
Expand Down Expand Up @@ -204,12 +214,16 @@ fn resolve_service_target(
}
};

let auth = AuthData::Key(
AISession::get_api_key(adapter_kind, parts[0] == "atuinhub")
.await
.map_err(|e| genai::resolver::Error::Custom(e.to_string()))?,
);
service_target.auth = auth;
let key = AISession::get_api_key(adapter_kind, parts[0] == "atuinhub")
.await
.map_err(|e| genai::resolver::Error::Custom(e.to_string()))?;

if let Some(key) = key {
let auth = AuthData::Key(key);
service_target.auth = auth;
} else {
service_target.auth = AuthData::Key("".to_string());
}

let model_id = ModelIden::new(adapter_kind, parts[1]);
service_target.model = model_id;
Expand Down Expand Up @@ -354,12 +368,12 @@ impl AISession {
async fn get_api_key(
_adapter_kind: AdapterKind,
is_hub: bool,
) -> Result<String, AISessionError> {
) -> Result<Option<String>, AISessionError> {
if is_hub {
return Ok("".to_string());
return Ok(None);
}

Ok("".to_string())
Ok(None)
}

/// Run the session event loop.
Expand Down Expand Up @@ -598,14 +612,9 @@ impl AISession {
Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
log::trace!("Session {} received thought signature chunk", session_id,);
}
Ok(ChatStreamEvent::ToolCallChunk(tc_chunk)) => {
Ok(ChatStreamEvent::ToolCallChunk(_tc_chunk)) => {
// Tool call chunks are accumulated by genai internally
// We'll get the complete tool calls in the End event
log::trace!(
"Session {} received tool call chunk: {:?}",
session_id,
tc_chunk
);
}
Ok(ChatStreamEvent::ReasoningChunk(_)) => {
log::trace!("Session {} received reasoning chunk", session_id);
Expand Down
2 changes: 1 addition & 1 deletion backend/src/ai/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl fmt::Display for ModelSelection {
},
ModelSelection::Ollama { model, uri } => match uri {
Some(uri) => write!(f, "ollama::{model}::{}", uri.deref()),
None => write!(f, "ollama::{model}::default"),
None => write!(f, "ollama::{model}::http://localhost:11434/v1"),
},
}
}
Expand Down
3 changes: 2 additions & 1 deletion backend/src/commands/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,15 @@ pub async fn ai_send_message(
state: tauri::State<'_, AtuinState>,
session_id: Uuid,
message: String,
model: ModelSelection,
) -> Result<(), String> {
let sessions = state.ai_sessions.read().await;
let handle = sessions
.get(&session_id)
.ok_or_else(|| format!("Session {} not found", session_id))?;

handle
.send_user_message(message)
.send_user_message(message, model)
.await
.map_err(|e| e.to_string())
}
Expand Down
131 changes: 114 additions & 17 deletions src/components/Settings/Settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import handleDeepLink from "@/routes/root/deep";
import * as api from "@/api/api";
import InterpreterSelector from "@/lib/blocks/common/InterpreterSelector";
import AtuinEnv from "@/atuin_env";
import { OllamaSettings, useAIProviderSettings } from "@/state/settings_ai";

async function loadFonts(): Promise<string[]> {
const fonts = await invoke<string[]>("list_fonts");
Expand All @@ -49,7 +50,7 @@ async function loadFonts(): Promise<string[]> {
}

// Custom hook for managing settings
const useSettingsState = (
export const useSettingsState = (
_key: any,
initialValue: any,
settingsGetter: any,
Expand Down Expand Up @@ -109,6 +110,7 @@ interface SettingsSwitchProps {
onValueChange: (e: boolean) => void;
description: string;
className?: string;
isDisabled?: boolean;
}

const SettingSwitch = ({
Expand All @@ -117,11 +119,13 @@ const SettingSwitch = ({
onValueChange,
description,
className,
isDisabled,
}: SettingsSwitchProps) => (
<Switch
isSelected={isSelected}
onValueChange={onValueChange}
className={cn("flex justify-between items-center w-full", className)}
isDisabled={isDisabled || false}
>
<div className="flex flex-col">
<span>{label}</span>
Expand Down Expand Up @@ -1190,29 +1194,122 @@ const AISettings = () => {
const setAiEnabled = useStore((state) => state.setAiEnabled);
const setAiShareContext = useStore((state) => state.setAiShareContext);

return (
<>
<Card shadow="sm">
<CardBody className="flex flex-col gap-4 mb-4">
<h2 className="text-xl font-semibold">AI</h2>
<p className="text-sm text-default-500">
Configure AI-powered features in runbooks
</p>

<SettingSwitch
label="Enable AI features"
isSelected={aiEnabled}
onValueChange={setAiEnabled}
description="Enable AI block generation and editing (Cmd+Enter, Cmd+K, AI Agent Sidebar)"
/>

{aiEnabled && (
<SettingSwitch
className="ml-4"
label="Share document context"
isSelected={aiShareContext}
onValueChange={setAiShareContext}
description="Send document content to improve AI suggestions. Disable for sensitive documents."
/>
)}
</CardBody>
</Card>
{aiEnabled && (
<>
<AgentSettings />
<AIOllamaSettings />
</>
)}
</>
);
};

// Settings card for choosing which backend the AI agent talks to.
// The selection is persisted via the "ai_provider" setting.
const AgentSettings = () => {
  // [display label, persisted provider id] pairs.
  const providers = [
    ["Atuin Hub", "atuinhub"],
    ["Ollama", "ollama"],
  ];

  const [aiProvider, setAiProvider, aiProviderLoading] = useSettingsState(
    "ai_provider",
    "atuinhub",
    Settings.aiAgentProvider,
    Settings.aiAgentProvider,
  );

  // Persist the newly selected provider key, ignoring empty selections.
  const handleProviderChange = (keys: SharedSelection) => {
    const key = keys.currentKey as string;
    if (key) {
      setAiProvider(key);
    }
  };

  return (
    <Card shadow="sm">
      <CardBody className="flex flex-col gap-4">
        <h2 className="text-xl font-semibold">AI Agent</h2>

        {/* Disabled until the persisted value has loaded to avoid a flash
            of the default selection. NOTE(review): `value` alongside
            `selectedKeys` looks redundant for this Select — confirm. */}
        <Select
          label="Provider"
          value={aiProvider}
          onSelectionChange={handleProviderChange}
          className="mt-4"
          placeholder="Select provider"
          selectedKeys={[aiProvider]}
          items={providers.map(([name, id]) => ({ label: name, key: id }))}
          isDisabled={aiProviderLoading}
        >
          {(item) => <SelectItem key={item.key}>{item.label}</SelectItem>}
        </Select>
      </CardBody>
    </Card>
  );
};

const AIOllamaSettings = () => {
const [ollamaSettings, setOllamaSettings, isLoading] = useAIProviderSettings<OllamaSettings>("ollama", {
enabled: false,
endpoint: "http://localhost:11434",
model: "",
});

return (
<Card shadow="sm">
<CardBody className="flex flex-col gap-4">
<h2 className="text-xl font-semibold">Ollama</h2>

<SettingSwitch
label="Enable AI features"
isSelected={aiEnabled}
onValueChange={setAiEnabled}
description="Enable AI block generation and editing (Cmd+Enter, Cmd+K)"
label="Enable Ollama AI provider"
isSelected={ollamaSettings.enabled}
onValueChange={(enabled) => setOllamaSettings({ ...ollamaSettings, enabled })}
description="Toggle to use Ollama as the AI provider."
/>

{aiEnabled && (
<SettingSwitch
className="ml-4"
label="Share document context"
isSelected={aiShareContext}
onValueChange={setAiShareContext}
description="Send document content to improve AI suggestions. Disable for sensitive documents."
/>
{ollamaSettings.enabled && (
<div className="flex flex-col gap-4">
<Input
label="Endpoint"
placeholder="Endpoint URL (e.g. http://localhost:11434)"
value={ollamaSettings.endpoint}
onValueChange={(value) => setOllamaSettings({ ...ollamaSettings, endpoint: value })}
isDisabled={isLoading}
/>

<Input
label="Model"
placeholder="Model name"
value={ollamaSettings.model}
onValueChange={(value) => setOllamaSettings({ ...ollamaSettings, model: value })}
isDisabled={isLoading}
/>
</div>
)}
</CardBody>
</Card>
Expand Down
21 changes: 18 additions & 3 deletions src/components/runbooks/editor/ui/AIAssistant.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ import { Settings } from "@/state/settings";
import { useStore } from "@/state/store";
import { ChargeTarget } from "@/rs-bindings/ChargeTarget";
import AtuinEnv from "@/atuin_env";
import { getModelSelection } from "@/state/settings_ai";
import { DialogBuilder } from "@/components/Dialogs/dialog";

const ALL_TOOL_NAMES = [
"get_runbook_document",
Expand Down Expand Up @@ -628,11 +630,24 @@ export default function AIAssistant({
}
}, [isOpen, sessionId]);

const handleSend = useCallback(() => {
const handleSend = useCallback(async () => {
if (!inputValue.trim() || isStreaming || !sessionId) return;
// TODO: Allow buffering one message while streaming
sendMessage(inputValue.trim());

const input = inputValue.trim();
setInputValue("");

const aiProvider = await Settings.aiAgentProvider();
const modelSelection = await getModelSelection(aiProvider);
if (modelSelection.isNone()) {
await new DialogBuilder().title("AI Provider Not Configured")
.message("Please configure your selected AI provider in the settings.")
.action({ label: "OK", value: undefined, variant: "flat" })
.build();
return;
}

// TODO: Allow buffering one message while streaming
sendMessage(input, modelSelection.unwrap());
}, [inputValue, isStreaming, sessionId, sendMessage]);

const handleKeyDown = (e: React.KeyboardEvent) => {
Expand Down
4 changes: 2 additions & 2 deletions src/lib/ai/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ export async function changeUser(sessionId: string, user: string): Promise<void>
/**
* Send a user message to an AI session.
*/
export async function sendMessage(sessionId: string, message: string): Promise<void> {
await invoke("ai_send_message", { sessionId, message });
export async function sendMessage(sessionId: string, message: string, model?: ModelSelection): Promise<void> {
await invoke("ai_send_message", { sessionId, message, model });
}

/**
Expand Down
Loading
Loading