Skip to content

Commit ad7c277

Browse files
committed
Fix memories
1 parent 56db2db commit ad7c277

File tree

4 files changed

+314
-1
lines changed

4 files changed

+314
-1
lines changed

src/config.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ pub struct Config {
3434
#[arg(long, env, default_value = "tngtech/deepseek-r1t2-chimera:free")]
3535
pub openrouter_model: String,
3636
#[arg(long, env)]
37+
pub openrouter_memory_model: Option<String>,
38+
#[arg(long, env)]
3739
pub openrouter_site_url: Option<String>,
3840
#[arg(long, env)]
3941
pub openrouter_site_name: Option<String>,

src/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ mod config;
4343
mod database;
4444
mod event_handler;
4545
mod math_test;
46+
mod memory_creator;
4647
mod message_handler;
4748
mod quiz_handler;
4849
mod structs;

src/memory_creator.rs

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
use color_eyre::Result;
2+
use openrouter_api::{
3+
types::chat::{ChatCompletionRequest, Message, MessageContent},
4+
OpenRouterClient,
5+
};
6+
use r2d2_sqlite::SqliteConnectionManager;
7+
use serde::{Deserialize, Serialize};
8+
use serde_rusqlite::from_row;
9+
use std::{collections::HashMap, sync::Arc};
10+
use wb_sqlite::InsertSync;
11+
12+
use crate::{
13+
config::Config,
14+
database::{Memory, User},
15+
};
16+
17+
/// JSON response structure for memory creation
18+
#[derive(Debug, Serialize, Deserialize)]
19+
struct MemoryCreationResponse {
20+
memories: Vec<MemoryEntry>,
21+
}
22+
23+
#[derive(Debug, Serialize, Deserialize)]
24+
struct MemoryEntry {
25+
username: String,
26+
key: String,
27+
content: String,
28+
}
29+
30+
/// Build the system prompt for memory creation
31+
fn build_memory_prompt(context: &str, participants: &str) -> String {
32+
format!(
33+
r#"You are a memory creation system for a Discord bot. Your job is to extract important facts, preferences, and information about users from conversations.
34+
35+
Analyze the following conversation and create memories for each participant. Focus on:
36+
- Personal preferences and interests
37+
- Facts about their life, work, or hobbies
38+
- Relationships with other users
39+
- Behaviors and patterns
40+
- Important events or milestones
41+
42+
Participants in this conversation: {participants}
43+
44+
Conversation:
45+
{context}
46+
47+
Respond ONLY with valid JSON in this exact format:
48+
{{
49+
"memories": [
50+
{{
51+
"username": "exact_username_from_conversation",
52+
"key": "category_or_topic",
53+
"content": "the actual memory content"
54+
}}
55+
]
56+
}}
57+
58+
IMPORTANT GUIDELINES:
59+
- Only create memories if there's meaningful information from THIS conversation
60+
- Each user should have AT MOST ONE memory entry per unique "key" category
61+
- The "key" should be a broad category like "preferences", "hobbies", "work", "personality", "relationships", "recent_activity"
62+
- The "content" should combine ALL related facts for that category into ONE comprehensive entry
63+
- Use exact usernames as they appear in the conversation
64+
- If there's nothing meaningful to remember, return an empty memories array
65+
66+
EXAMPLE - CORRECT (combining multiple facts under one key):
67+
{{
68+
"memories": [
69+
{{
70+
"username": "tricked.",
71+
"key": "preferences",
72+
"content": "Likes cats, dislikes insects"
73+
}}
74+
]
75+
}}
76+
77+
EXAMPLE - WRONG (duplicate keys for same user):
78+
{{
79+
"memories": [
80+
{{"username": "tricked.", "key": "preferences", "content": "Likes cats"}},
81+
{{"username": "tricked.", "key": "preferences", "content": "Dislikes insects"}}
82+
]
83+
}}
84+
85+
Remember: Output ONLY valid JSON, nothing else. Combine related information under the same key."#,
86+
context = context,
87+
participants = participants
88+
)
89+
}
90+
91+
/// Resolve usernames to user IDs using the database
92+
fn resolve_username_to_id(
93+
database: &r2d2::Pool<SqliteConnectionManager>,
94+
username: &str,
95+
) -> Option<u64> {
96+
let db = database.get().ok()?;
97+
let mut stmt = db
98+
.prepare("SELECT * FROM user WHERE name = ? COLLATE NOCASE")
99+
.ok()?;
100+
101+
stmt.query_one([username], |row| {
102+
from_row::<User>(row).map_err(|_| rusqlite::Error::QueryReturnedNoRows)
103+
})
104+
.ok()
105+
.map(|user| user.id)
106+
}
107+
108+
/// Insert a new memory in the database
109+
/// If a memory with the same user_id and key exists, delete it first then insert the new one
110+
/// This ensures the new memory has the latest timestamp and will be in the top 5 most recent
111+
fn insert_memory(
112+
database: &r2d2::Pool<SqliteConnectionManager>,
113+
user_id: u64,
114+
key: &str,
115+
content: &str,
116+
) -> Result<()> {
117+
let db = database.get()?;
118+
119+
// Delete any existing memory with the same user_id and key
120+
db.execute(
121+
"DELETE FROM memory WHERE user_id = ? AND key = ?",
122+
rusqlite::params![user_id.to_string(), key],
123+
)?;
124+
125+
// Insert the new memory (will get a new ID and timestamp)
126+
let memory = Memory {
127+
id: 0, // Will be auto-generated
128+
user_id: user_id.to_string(),
129+
content: content.to_string(),
130+
key: key.to_string(),
131+
};
132+
memory.insert_sync(&db)?;
133+
134+
Ok(())
135+
}
136+
137+
/// Process memory creation response and store in database
138+
fn process_memory_response(
139+
database: &r2d2::Pool<SqliteConnectionManager>,
140+
response_text: &str,
141+
) -> Result<usize> {
142+
log::info!("Raw memory response: {}", response_text);
143+
144+
// Try to extract JSON from the response (in case the model adds extra text)
145+
let json_start = response_text.find('{').unwrap_or(0);
146+
let json_end = response_text.rfind('}').map(|i| i + 1).unwrap_or(response_text.len());
147+
let json_text = &response_text[json_start..json_end];
148+
149+
let memory_response: MemoryCreationResponse = serde_json::from_str(json_text)
150+
.map_err(|e| color_eyre::eyre::eyre!("Failed to parse memory JSON: {} - Raw: {}", e, json_text))?;
151+
152+
let mut created_count = 0;
153+
154+
for entry in memory_response.memories {
155+
// Resolve username to user ID
156+
if let Some(user_id) = resolve_username_to_id(database, &entry.username) {
157+
match insert_memory(database, user_id, &entry.key, &entry.content) {
158+
Ok(_) => {
159+
log::info!(
160+
"Created memory for user {} ({}): {} = {}",
161+
entry.username,
162+
user_id,
163+
entry.key,
164+
entry.content
165+
);
166+
created_count += 1;
167+
}
168+
Err(e) => {
169+
log::error!("Failed to insert memory for {}: {}", entry.username, e);
170+
}
171+
}
172+
} else {
173+
log::warn!(
174+
"Could not resolve username '{}' to user ID, skipping memory",
175+
entry.username
176+
);
177+
}
178+
}
179+
180+
Ok(created_count)
181+
}
182+
183+
/// Main function to create memories in the background
184+
pub async fn create_memories_background(
185+
database: r2d2::Pool<SqliteConnectionManager>,
186+
context: String,
187+
user_mentions: HashMap<String, u64>,
188+
config: Arc<Config>,
189+
) {
190+
// Get API key
191+
let api_key = match &config.openrouter_api_key {
192+
Some(key) => key.clone(),
193+
None => {
194+
log::warn!("OpenRouter API key not configured, skipping memory creation");
195+
return;
196+
}
197+
};
198+
199+
// Determine which model to use (memory model or default to main model)
200+
let model = config
201+
.openrouter_memory_model
202+
.clone()
203+
.unwrap_or_else(|| config.openrouter_model.clone());
204+
205+
log::info!("Creating memories using model: {}", model);
206+
207+
// Create client
208+
let client = match OpenRouterClient::new()
209+
.skip_url_configuration()
210+
.with_retries(3, 1000)
211+
.with_timeout_secs(120)
212+
.configure(
213+
&api_key,
214+
config.openrouter_site_url.as_deref(),
215+
config.openrouter_site_name.as_deref(),
216+
) {
217+
Ok(c) => c,
218+
Err(e) => {
219+
log::error!("Failed to create OpenRouter client for memory creation: {}", e);
220+
return;
221+
}
222+
};
223+
224+
// Build list of participants
225+
let participants: Vec<String> = user_mentions
226+
.iter()
227+
.filter_map(|(_mention, user_id)| {
228+
let db = database.get().ok()?;
229+
let mut stmt = db.prepare("SELECT * FROM user WHERE id = ?").ok()?;
230+
stmt.query_one([user_id.to_string()], |row| {
231+
from_row::<User>(row).map_err(|_| rusqlite::Error::QueryReturnedNoRows)
232+
})
233+
.ok()
234+
.map(|user| user.name)
235+
})
236+
.collect();
237+
238+
let participants_str = participants.join(", ");
239+
240+
// Build the memory creation prompt
241+
let system_prompt = build_memory_prompt(&context, &participants_str);
242+
243+
log::debug!("Memory creation prompt: {}", system_prompt);
244+
245+
// Build request (non-streaming)
246+
let request = ChatCompletionRequest {
247+
model,
248+
messages: vec![Message {
249+
role: "user".to_string(),
250+
content: MessageContent::Text(system_prompt),
251+
..Default::default()
252+
}],
253+
max_tokens: Some(2048),
254+
stream: Some(false), // Disable streaming for memory creation
255+
..Default::default()
256+
};
257+
258+
// Create chat client
259+
let chat_client = match client.chat() {
260+
Ok(c) => c,
261+
Err(e) => {
262+
log::error!("Failed to create chat client for memory creation: {}", e);
263+
return;
264+
}
265+
};
266+
267+
// Get the complete response (non-streaming)
268+
let response = match chat_client.chat_completion(request).await {
269+
Ok(r) => r,
270+
Err(e) => {
271+
log::error!("Error getting memory creation response: {:?}", e);
272+
return;
273+
}
274+
};
275+
276+
// Extract the content from the response
277+
let response_text = match response.choices.first() {
278+
Some(choice) => match &choice.message.content {
279+
MessageContent::Text(text) => text.clone(),
280+
_ => {
281+
log::error!("No text content in memory creation response");
282+
return;
283+
}
284+
},
285+
None => {
286+
log::error!("No choices in memory creation response");
287+
return;
288+
}
289+
};
290+
291+
// Process the response
292+
match process_memory_response(&database, &response_text) {
293+
Ok(count) => {
294+
log::info!("Successfully created {} memories", count);
295+
}
296+
Err(e) => {
297+
log::error!("Failed to process memory response: {}", e);
298+
}
299+
}
300+
}

src/message_handler.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use twilight_model::{gateway::payload::incoming::MessageCreate, id::{marker::{Ch
77
use vesper::twilight_exports::UserMarker;
88

99
use crate::{
10-
ai_message, database::User, quiz_handler, structs::{Command, List, State},
10+
ai_message, database::User, memory_creator, quiz_handler, structs::{Command, List, State},
1111
utils::levels::xp_required_for_level, zalgos::zalgify_text, RESPONDERS,
1212
};
1313

@@ -283,6 +283,7 @@ pub async fn handle_message(
283283
}
284284
}
285285

286+
let user_mentions_clone = user_mentions.clone();
286287
match ai_message::main(
287288
locked_state.db.clone(),
288289
msg.author.id.get(),
@@ -301,6 +302,15 @@ pub async fn handle_message(
301302
msg.id,
302303
Arc::clone(http),
303304
));
305+
306+
// Spawn background task to create memories
307+
tokio::spawn(memory_creator::create_memories_background(
308+
locked_state.db.clone(),
309+
context.clone(),
310+
user_mentions_clone,
311+
locked_state.config.clone(),
312+
));
313+
304314
Ok(Command::nothing())
305315
}
306316
Err(e) => Ok(Command::text(format!("AI Error: {:?}", e)).reply()),

0 commit comments

Comments
 (0)