|
| 1 | +use crate::database::MathQuestion; |
| 2 | +use color_eyre::Result; |
| 3 | +use rand::Rng; |
| 4 | +use rig::{completion::Prompt, providers, client::CompletionClient}; |
| 5 | +use r2d2::Pool; |
| 6 | +use r2d2_sqlite::SqliteConnectionManager; |
| 7 | +use serde_rusqlite::from_rows; |
| 8 | + |
| 9 | +const MATH_PROMPTS: [&str; 3] = [ |
| 10 | + "Generate a simple mental math expression that can be solved in your head. Use basic operations like addition, subtraction, multiplication, or division with small numbers (prefer numbers under 20, maximum 100). Output ONLY the mathematical expression, nothing else. Examples: {ex1}, {ex2}, {ex3}, {ex4}", |
| 11 | + "Create an easy math problem with 2-4 numbers that someone can calculate mentally. Use addition, subtraction, multiplication, or simple division. Keep numbers small and friendly (single or double digits preferred). Output ONLY the expression. Examples: {ex1}, {ex2}, {ex3}, {ex4}", |
| 12 | + "Write a simple arithmetic calculation using small, friendly numbers that can be computed without a calculator. Stick to basic operations (+, -, *, /). Output ONLY the math expression. Examples: {ex1}, {ex2}, {ex3}, {ex4}" |
| 13 | +]; |
| 14 | + |
| 15 | +pub struct MathTest { |
| 16 | + pub question: String, |
| 17 | + pub answer: f64, |
| 18 | +} |
| 19 | + |
| 20 | +impl MathTest { |
| 21 | + pub async fn generate( |
| 22 | + openai_api_key: &str, |
| 23 | + db: &Pool<SqliteConnectionManager>, |
| 24 | + rng: &mut impl Rng, |
| 25 | + ) -> Result<Self> { |
| 26 | + let client = providers::openai::Client::new(openai_api_key); |
| 27 | + |
| 28 | + // Retry loop to avoid recursion |
| 29 | + for attempt in 0..5 { |
| 30 | + // Generate random example questions to show the AI |
| 31 | + // Example 1: Addition |
| 32 | + let ex1_a = rng.gen_range(5..50); |
| 33 | + let ex1_b = rng.gen_range(5..50); |
| 34 | + let ex1 = format!("{} + {}", ex1_a, ex1_b); |
| 35 | + |
| 36 | + // Example 2: Multiplication |
| 37 | + let ex2_a = rng.gen_range(3..15); |
| 38 | + let ex2_b = rng.gen_range(3..15); |
| 39 | + let ex2 = format!("{} * {}", ex2_a, ex2_b); |
| 40 | + |
| 41 | + // Example 3: Division or Subtraction |
| 42 | + let ex3 = if rng.gen_bool(0.5) { |
| 43 | + let divisor = rng.gen_range(2..13); |
| 44 | + let result = rng.gen_range(5..20); |
| 45 | + format!("{} / {}", divisor * result, divisor) |
| 46 | + } else { |
| 47 | + let ex3_a = rng.gen_range(30..100); |
| 48 | + let ex3_b = rng.gen_range(5..30); |
| 49 | + format!("{} - {}", ex3_a, ex3_b) |
| 50 | + }; |
| 51 | + |
| 52 | + // Example 4: Multi-operation or simple operation |
| 53 | + let ex4 = if rng.gen_bool(0.3) { |
| 54 | + // Multi-operation |
| 55 | + let a = rng.gen_range(3..20); |
| 56 | + let b = rng.gen_range(3..20); |
| 57 | + let c = rng.gen_range(3..20); |
| 58 | + match rng.gen_range(0..3) { |
| 59 | + 0 => format!("{} + {} + {}", a, b, c), |
| 60 | + 1 => format!("{} - {} + {}", a + b + c, b, c), |
| 61 | + _ => format!("{} + {} - {}", a, b, c), |
| 62 | + } |
| 63 | + } else { |
| 64 | + // Simple division |
| 65 | + let divisor = rng.gen_range(2..11); |
| 66 | + let result = rng.gen_range(5..15); |
| 67 | + format!("{} / {}", divisor * result, divisor) |
| 68 | + }; |
| 69 | + |
| 70 | + // Select a random prompt template and fill in the example questions |
| 71 | + let prompt_template = MATH_PROMPTS[rng.gen_range(0..MATH_PROMPTS.len())]; |
| 72 | + let prompt = prompt_template |
| 73 | + .replace("{ex1}", &ex1) |
| 74 | + .replace("{ex2}", &ex2) |
| 75 | + .replace("{ex3}", &ex3) |
| 76 | + .replace("{ex4}", &ex4); |
| 77 | + |
| 78 | + // Generate question using AI |
| 79 | + let agent = client |
| 80 | + .agent("gpt-4o") |
| 81 | + .preamble( |
| 82 | + "You are a math expression generator for mental math challenges. \ |
| 83 | + Generate simple arithmetic expressions that can be solved mentally. \ |
| 84 | + Output ONLY the mathematical expression with numbers and operators, no explanations, no greetings, no additional text." |
| 85 | + ) |
| 86 | + .temperature(0.9) |
| 87 | + .max_tokens(50) |
| 88 | + .build(); |
| 89 | + |
| 90 | + let question = match agent.prompt(&prompt).await { |
| 91 | + Ok(q) => q.trim().to_string(), |
| 92 | + Err(e) => { |
| 93 | + tracing::error!("AI generation failed on attempt {}: {:?}", attempt + 1, e); |
| 94 | + continue; |
| 95 | + } |
| 96 | + }; |
| 97 | + |
| 98 | + // Validate the expression with fasteval |
| 99 | + let mut ns = fasteval::EmptyNamespace; |
| 100 | + let answer = match fasteval::ez_eval(&question, &mut ns) { |
| 101 | + Ok(result) => result, |
| 102 | + Err(e) => { |
| 103 | + tracing::error!("Invalid math expression generated '{}': {:?}", question, e); |
| 104 | + continue; |
| 105 | + } |
| 106 | + }; |
| 107 | + |
| 108 | + // Check if this question already exists in the database |
| 109 | + let conn = db.get()?; |
| 110 | + let mut stmt = conn.prepare("SELECT * FROM math_question WHERE question = ?")?; |
| 111 | + let existing: Result<Vec<MathQuestion>, _> = from_rows(stmt.query([&question])?).collect(); |
| 112 | + |
| 113 | + // If question exists, try again |
| 114 | + if existing.is_ok() && !existing.as_ref().unwrap().is_empty() { |
| 115 | + tracing::debug!("Question '{}' already exists, retrying", question); |
| 116 | + continue; |
| 117 | + } |
| 118 | + |
| 119 | + // Store in database - use manual INSERT to let SQLite handle autoincrement |
| 120 | + conn.execute( |
| 121 | + "INSERT INTO math_question (question, answer) VALUES (?, ?)", |
| 122 | + rusqlite::params![&question, &answer], |
| 123 | + )?; |
| 124 | + |
| 125 | + return Ok(MathTest { question, answer }); |
| 126 | + } |
| 127 | + |
| 128 | + // If all attempts failed, return error |
| 129 | + Err(color_eyre::eyre::eyre!("Failed to generate valid math question after 5 attempts")) |
| 130 | + } |
| 131 | + |
| 132 | + pub fn validate_answer(&self, user_answer: &str) -> bool { |
| 133 | + // Parse user answer |
| 134 | + let user_answer = match user_answer.trim().parse::<f64>() { |
| 135 | + Ok(v) => v, |
| 136 | + Err(_) => return false, |
| 137 | + }; |
| 138 | + |
| 139 | + // Check if answer is within 0.1 tolerance (1 decimal place) |
| 140 | + (self.answer - user_answer).abs() <= 0.1 |
| 141 | + } |
| 142 | +} |
0 commit comments