diff --git a/README.md b/README.md
index ae04239ae6..bb10925fac 100644
--- a/README.md
+++ b/README.md
@@ -69,6 +69,18 @@ Codex CLI supports [MCP servers](./docs/advanced.md#model-context-protocol-mcp).
 
 Codex CLI supports a rich set of configuration options, with preferences stored in `~/.codex/config.toml`. For full configuration options, see [Configuration](./docs/config.md).
 
+### Guardrails (Optional)
+
+Codex can optionally prepend project guardrails to your prompt. Guardrails come from an executable `guardloop_bridge.py` script in your project directory: it receives the prompt as its first argument and prints a JSON object whose `guardrails` field holds the Markdown to prepend. Enable guardrails with the CLI flag or environment variable:
+
+```bash
+codex --guardrails
+# or
+CODEX_GUARDRAILS=1 codex
+```
+
+When guardrails are enabled and the bridge returns any, they are prepended to your initial prompt, ensuring Codex receives important project guidelines up front.
+
 ---
 
 ### Docs & FAQ
diff --git a/__tests__/guardrails.test.ts b/__tests__/guardrails.test.ts
new file mode 100644
index 0000000000..3cf2f11bf5
--- /dev/null
+++ b/__tests__/guardrails.test.ts
@@ -0,0 +1,74 @@
+import assert from "node:assert/strict";
+import { chmod, copyFile, mkdtemp, rm } from "node:fs/promises";
+import { tmpdir } from "node:os";
+import path from "node:path";
+import test from "node:test";
+import { fileURLToPath } from "node:url";
+
+import { buildPromptWithGuardrails } from "../src/cli.ts";
+import { loadGuardrails } from "../src/extensions/guardrails.ts";
+
+// Locate the bridge relative to this file so the suite passes no matter which
+// directory the test runner is started from.
+const BRIDGE_SOURCE_PATH = fileURLToPath(new URL("../guardloop_bridge.py", import.meta.url));
+
+// Guardrail text the bridge prints for prompts it classifies as code tasks.
+const CODE_GUARDRAILS = [
+  "## Guardrail: Code Standard",
+  "- All functions must have a docstring.",
+  "- Wrap async database calls in try-catch blocks.",
+].join("\n");
+
+// Create a scratch project holding an executable copy of the bridge, and
+// delete it again once the calling test has finished.
+async function createTempProjectWithBridge(t) {
+  const dir = await mkdtemp(path.join(tmpdir(), "codex-guardrails-"));
+  t.after(() => rm(dir, { recursive: true, force: true }));
+
+  const bridgeDestPath = path.join(dir, "guardloop_bridge.py");
+  await copyFile(BRIDGE_SOURCE_PATH, bridgeDestPath);
+  await chmod(bridgeDestPath, 0o755);
+  return { dir };
+}
+
+test("loadGuardrails returns code guardrails for a code prompt", async (t) => {
+  const { dir } = await createTempProjectWithBridge(t);
+  const codePrompt = "implement a function to sort a list";
+
+  const guardrails = await loadGuardrails({ cwd: dir, prompt: codePrompt });
+
+  assert.strictEqual(guardrails, CODE_GUARDRAILS);
+});
+
+test("loadGuardrails returns no guardrails for a creative prompt", async (t) => {
+  const { dir } = await createTempProjectWithBridge(t);
+  const creativePrompt = "write a blog post about AI";
+
+  const guardrails = await loadGuardrails({ cwd: dir, prompt: creativePrompt });
+
+  assert.strictEqual(guardrails, "");
+});
+
+test("buildPromptWithGuardrails prepends guardrails correctly for a code prompt", async (t) => {
+  const { dir } = await createTempProjectWithBridge(t);
+  const userPrompt = "implement user authentication service";
+
+  const result = await buildPromptWithGuardrails(userPrompt, {
+    cwd: dir,
+    guardrailsEnabled: true,
+  });
+
+  assert.strictEqual(result, `${CODE_GUARDRAILS}\n\n${userPrompt}`);
+});
+
+test("buildPromptWithGuardrails returns only the prompt for a creative prompt", async (t) => {
+  const { dir } = await createTempProjectWithBridge(t);
+  const userPrompt = "write a poem about the sea";
+
+  const result = await buildPromptWithGuardrails(userPrompt, {
+    cwd: dir,
+    guardrailsEnabled: true,
+  });
+
+  assert.strictEqual(result, userPrompt);
+});
diff --git a/codex-rs/tui/src/app.rs b/codex-rs/tui/src/app.rs
index 9c30afdb4b..ff4bc48a87 100644
--- a/codex-rs/tui/src/app.rs
+++ b/codex-rs/tui/src/app.rs
@@ -363,6 +363,36 @@ impl App {
             AppEvent::OpenReviewCustomPrompt => {
                 self.chat_widget.show_review_custom_prompt();
             }
+            AppEvent::LogGuardLoopFailure(prompt) => {
+                let bridge_path = self.config.cwd.join("guardloop_bridge.py");
+                let output = std::process::Command::new(bridge_path)
+                    .arg("--log-failure")
+                    .arg(prompt)
+                    .output();
+
+                match output {
+                    Ok(output) => {
+                        if output.status.success() {
+                            self.chat_widget.add_info_message(
+                                "Feedback logged. Thank you!".to_string(),
+                                None,
+                            );
+                        } else {
+                            let stderr = String::from_utf8_lossy(&output.stderr);
+                            self.chat_widget.add_error_message(format!(
+                                "Failed to log feedback: {}",
+                                stderr
+                            ));
+                        }
+                    }
+                    Err(e) => {
+                        self.chat_widget.add_error_message(format!(
+                            "Failed to execute GuardLoop bridge: {}",
+                            e
+                        ));
+                    }
+                }
+            }
         }
         Ok(true)
     }
diff --git a/codex-rs/tui/src/app_event.rs b/codex-rs/tui/src/app_event.rs
index 56c66379a6..7bdf127232 100644
--- a/codex-rs/tui/src/app_event.rs
+++ b/codex-rs/tui/src/app_event.rs
@@ -76,4 +76,7 @@ pub(crate) enum AppEvent {
 
     /// Open the custom prompt option from the review popup.
     OpenReviewCustomPrompt,
+
+    /// Record the given prompt as a failed interaction via the GuardLoop bridge.
+    LogGuardLoopFailure(String),
 }
diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs
index 724feb29f2..3c265a25cf 100644
--- a/codex-rs/tui/src/chatwidget.rs
+++ b/codex-rs/tui/src/chatwidget.rs
@@ -225,6 +225,7 @@ pub(crate) struct ChatWidget {
     auth_manager: Arc,
     session_header: SessionHeader,
     initial_user_message: Option,
+    last_user_message: Option<UserMessage>,
     token_info: Option,
     rate_limit_snapshot: Option,
     rate_limit_warnings: RateLimitWarningState,
@@ -258,6 +259,7 @@ pub(crate) struct ChatWidget {
     needs_final_message_separator: bool,
 }
 
+#[derive(Clone)]
 struct UserMessage {
     text: String,
     image_paths: Vec,
@@ -385,6 +387,8 @@ impl ChatWidget {
         self.running_commands.clear();
         self.request_redraw();
 
+        self.show_feedback_prompt();
+
         // If there is a queued user message, send exactly one now to begin the next turn.
self.maybe_send_next_queued_input(); // Emit a notification when the turn completes (suppressed if focused). @@ -895,6 +899,7 @@ impl ChatWidget { initial_prompt.unwrap_or_default(), initial_images, ), + last_user_message: None, token_info: None, rate_limit_snapshot: None, rate_limit_warnings: RateLimitWarningState::default(), @@ -957,6 +962,7 @@ impl ChatWidget { initial_prompt.unwrap_or_default(), initial_images, ), + last_user_message: None, token_info: None, rate_limit_snapshot: None, rate_limit_warnings: RateLimitWarningState::default(), @@ -1220,6 +1226,7 @@ impl ChatWidget { } fn submit_user_message(&mut self, user_message: UserMessage) { + self.last_user_message = Some(user_message.clone()); let UserMessage { text, image_paths } = user_message; if text.is_empty() && image_paths.is_empty() { return; @@ -1982,6 +1989,48 @@ impl ChatWidget { let [_, _, bottom_pane_area] = self.layout_areas(area); self.bottom_pane.cursor_pos(bottom_pane_area) } + + pub(crate) fn show_feedback_prompt(&mut self) { + let last_prompt = if let Some(msg) = &self.last_user_message { + msg.text.clone() + } else { + return; // No last prompt, nothing to do. + }; + + if last_prompt.is_empty() { + return; // Don't ask for feedback on empty prompts + } + + let mut items = Vec::new(); + + // "Yes" item + items.push(SelectionItem { + name: "Yes".to_string(), + description: None, + is_current: false, + actions: Vec::new(), // No action needed, just dismiss. 
+ dismiss_on_select: true, + search_value: None, + }); + + // "No" item + items.push(SelectionItem { + name: "No".to_string(), + description: Some("Log this interaction as a failure".to_string()), + is_current: false, + actions: vec![Box::new(move |tx: &AppEventSender| { + tx.send(AppEvent::LogGuardLoopFailure(last_prompt.clone())); + })], + dismiss_on_select: true, + search_value: None, + }); + + self.bottom_pane.show_selection_view(SelectionViewParams { + title: "Was this helpful?".to_string(), + items, + ..Default::default() + }); + } } impl WidgetRef for &ChatWidget { diff --git a/codex-rs/tui/src/chatwidget/tests.rs b/codex-rs/tui/src/chatwidget/tests.rs index f3799ad336..037057f676 100644 --- a/codex-rs/tui/src/chatwidget/tests.rs +++ b/codex-rs/tui/src/chatwidget/tests.rs @@ -322,6 +322,7 @@ fn make_chatwidget_manual() -> ( auth_manager, session_header: SessionHeader::new(cfg.model), initial_user_message: None, + last_user_message: None, token_info: None, rate_limit_snapshot: None, rate_limit_warnings: RateLimitWarningState::default(), diff --git a/guardloop_bridge.py b/guardloop_bridge.py new file mode 100755 index 0000000000..9704e967c9 --- /dev/null +++ b/guardloop_bridge.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 + +import sys +import json +import sqlite3 +import os + +DB_PATH = os.path.expanduser("~/.guardloop/data.db") + +def init_database(): + """ + Initializes the SQLite database and the 'failures' table if they don't exist. + """ + os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS failures ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + prompt TEXT NOT NULL, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """) + conn.commit() + conn.close() + +def log_failure(prompt): + """ + Logs a failed prompt to the database. 
+ """ + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + cursor.execute("INSERT INTO failures (prompt) VALUES (?)", (prompt,)) + conn.commit() + conn.close() + +def classify_task(prompt): + """ + Classifies the task based on the prompt. + """ + code_keywords = ["implement", "function", "class", "debug", "test", "fix"] + creative_keywords = ["write", "blog post", "email", "poem", "summarize"] + + prompt_lower = prompt.lower() + if any(keyword in prompt_lower for keyword in code_keywords): + return {"classification": "code", "confidence": 0.95} + if any(keyword in prompt_lower for keyword in creative_keywords): + return {"classification": "creative", "confidence": 0.92} + return {"classification": "unknown", "confidence": 0.5} + +def get_guardrails_for_task(classification): + """ + Returns guardrails based on the task classification. + """ + if classification["classification"] == "code": + return [ + "## Guardrail: Code Standard", + "- All functions must have a docstring.", + "- Wrap async database calls in try-catch blocks." + ] + return [] + +def main(): + """ + Main function to handle commands for logging failures or getting guardrails. 
+ """ + init_database() + + if "--log-failure" in sys.argv: + try: + prompt_index = sys.argv.index("--log-failure") + 1 + prompt = sys.argv[prompt_index] + log_failure(prompt) + print(json.dumps({"status": "failure logged"})) + except (ValueError, IndexError): + print(json.dumps({"error": "No prompt provided for failure logging."}), file=sys.stderr) + sys.exit(1) + return + + prompt = sys.argv[1] if len(sys.argv) > 1 else "" + if not prompt: + sys.exit(0) + + classification = classify_task(prompt) + guardrails = get_guardrails_for_task(classification) + + output = { + "classification": classification, + "guardrails": "\n".join(guardrails) + } + print(json.dumps(output)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/package.json b/package.json index 606c323ff4..779c7b9ec3 100644 --- a/package.json +++ b/package.json @@ -4,7 +4,8 @@ "description": "Tools for repo-wide maintenance.", "scripts": { "format": "prettier --check *.json *.md .github/workflows/*.yml **/*.js", - "format:fix": "prettier --write *.json *.md .github/workflows/*.yml **/*.js" + "format:fix": "prettier --write *.json *.md .github/workflows/*.yml **/*.js", + "test": "node --loader ./ts-loader.mjs --test __tests__/guardrails.test.ts" }, "devDependencies": { "prettier": "^3.5.3" diff --git a/src/cli.ts b/src/cli.ts new file mode 100644 index 0000000000..96a05eaeda --- /dev/null +++ b/src/cli.ts @@ -0,0 +1,74 @@ +import { loadGuardrails } from "./extensions/guardrails.ts"; + +export const GUARDRAILS_ENV_VAR = "CODEX_GUARDRAILS"; + +const GUARDRAIL_FLAG = "--guardrails"; +const GUARDRAIL_FLAG_ALIAS = "-g"; +const GUARDRAIL_FLAG_DISABLE = "--no-guardrails"; + +function parseBoolean(value) { + if (value == null) { + return null; + } + + const normalized = String(value).trim().toLowerCase(); + if (["1", "true", "yes", "on"].includes(normalized)) { + return true; + } + + if (["0", "false", "no", "off"].includes(normalized)) { + return false; + } + + return null; +} + 
// ---------------------------------------------------------------------------
// src/cli.ts (continued)
//
// GUARDRAIL_FLAG, GUARDRAIL_FLAG_ALIAS, GUARDRAIL_FLAG_DISABLE,
// GUARDRAILS_ENV_VAR and parseBoolean() are declared earlier in src/cli.ts.
// ---------------------------------------------------------------------------

/**
 * Find the first guardrails switch in `argv`.
 *
 * Recognises the enabling flag and its alias, the disabling flag, and the
 * `--guardrails=<value>` form. Scanning stops at the first match.
 *
 * @param {string[]} argv - Command-line arguments, in the order given.
 * @returns {boolean | null} The preference expressed by the first switch
 *   found, or `null` when argv contains no guardrails switch.
 */
function parseGuardrailFlag(argv) {
  for (const arg of argv) {
    if (arg === GUARDRAIL_FLAG || arg === GUARDRAIL_FLAG_ALIAS) {
      return true;
    }

    if (arg === GUARDRAIL_FLAG_DISABLE) {
      return false;
    }

    if (arg.startsWith("--guardrails=")) {
      const [, rawValue] = arg.split("=", 2);
      const parsed = parseBoolean(rawValue);
      // A value parseBoolean does not recognise still enables guardrails,
      // unless it is empty (`--guardrails=`).
      return parsed !== null ? parsed : rawValue !== "";
    }
  }

  return null;
}

/**
 * Decide whether guardrails are enabled. A command-line switch wins over the
 * environment variable; with neither present guardrails are off.
 *
 * @param {{ argv?: string[], env?: Record<string, string | undefined> }} [options]
 * @returns {boolean} Whether guardrails should be loaded.
 */
export function shouldUseGuardrails({ argv = process.argv.slice(2), env = process.env } = {}) {
  const cliPreference = parseGuardrailFlag(argv);
  if (cliPreference !== null) {
    return cliPreference;
  }

  const envPreference = parseBoolean(env[GUARDRAILS_ENV_VAR]);
  return envPreference === null ? false : envPreference;
}

/**
 * Prepend guardrails to the user's prompt when they are enabled and the
 * bridge returns any; otherwise return the prompt unchanged.
 *
 * @param {string} userPrompt - The prompt as typed by the user.
 * @param {{ argv?: string[], env?: object, cwd?: string, guardrailsEnabled?: boolean }} [options]
 *   A boolean `guardrailsEnabled` overrides argv/env detection.
 * @returns {Promise<string>} The prompt, preceded by the guardrails and a
 *   blank line when guardrails apply.
 */
export async function buildPromptWithGuardrails(userPrompt, options = {}) {
  const { argv = process.argv.slice(2), env = process.env, cwd = process.cwd(), guardrailsEnabled } = options;
  const enabled =
    typeof guardrailsEnabled === "boolean" ? guardrailsEnabled : shouldUseGuardrails({ argv, env });

  if (!enabled) {
    return userPrompt;
  }

  const guardrails = await loadGuardrails({ cwd, prompt: userPrompt });
  if (!guardrails) {
    return userPrompt;
  }

  return `${guardrails}\n\n${userPrompt}`;
}

// ---------------------------------------------------------------------------
// src/extensions/guardrails.ts
// ---------------------------------------------------------------------------
import path from "node:path";
import { execFileSync } from "node:child_process";

/** Log a bridge problem in the `label message` form used by this module. */
function reportBridgeProblem(label, error) {
  console.error(label, error instanceof Error ? error.message : String(error));
}

/**
 * Pull the guardrail text out of the bridge's stdout.
 *
 * Empty output and payloads without a string `guardrails` field both mean
 * "no guardrails". Malformed JSON throws so the caller can report it.
 */
function extractGuardrails(stdout) {
  const text = String(stdout).trim();
  if (text === "") {
    return "";
  }

  const result = JSON.parse(text);
  if (result === null || typeof result !== "object") {
    return "";
  }

  return typeof result.guardrails === "string" ? result.guardrails.trim() : "";
}

/**
 * Load guardrails from the GuardLoop bridge.
 *
 * Runs `guardloop_bridge.py` from `cwd` with the prompt as its only argument
 * and returns the trimmed `guardrails` string from the JSON it prints. Every
 * failure is logged and reported as "no guardrails" rather than thrown.
 *
 * NOTE(review): this executes a script found in the project directory, so
 * guardrails should only be enabled for projects whose contents are trusted.
 *
 * @param {{ cwd?: string, prompt?: string, runBridge?: Function }} [options]
 *   `runBridge(file, args, execOptions)` must return the bridge's stdout; it
 *   defaults to `execFileSync` and exists so callers can substitute a runner.
 * @returns {Promise<string>} Combined guardrail contents, or "" when there
 *   are none.
 */
export async function loadGuardrails(options) {
  const { cwd = process.cwd(), prompt = "", runBridge = execFileSync } = options || {};

  // There is nothing to classify in a blank prompt, so skip the spawn (and
  // the parse error an output-less bridge run would otherwise cause).
  if (typeof prompt === "string" && prompt.trim() === "") {
    return "";
  }

  const bridgePath = path.join(cwd, "guardloop_bridge.py");

  let stdout;
  try {
    stdout = runBridge(bridgePath, [prompt], { encoding: "utf-8" });
  } catch (error) {
    reportBridgeProblem("Failed to execute GuardLoop bridge:", error);
    return "";
  }

  try {
    return extractGuardrails(stdout);
  } catch (error) {
    reportBridgeProblem("Failed to parse GuardLoop bridge output:", error);
    return "";
  }
}

// ---------------------------------------------------------------------------
// ts-loader.mjs
// ---------------------------------------------------------------------------
import { readFile } from "node:fs/promises";
import { pathToFileURL } from "node:url";

/**
 * Module resolve hook: map specifiers ending in `.ts` to a URL relative to
 * the importing module (or to the working directory when there is no
 * parent), and hand every other specifier to the default resolver.
 */
export async function resolve(specifier, context, defaultResolve) {
  if (!specifier.endsWith(".ts")) {
    return defaultResolve(specifier, context, defaultResolve);
  }

  const base = context.parentURL || pathToFileURL(`${process.cwd()}/`);
  const { href } = new URL(specifier, base);
  return { url: href, format: "module", shortCircuit: true };
}

/**
 * Module load hook: serve `.ts` files as ES modules and delegate everything
 * else. The file text is passed through unchanged, so only `.ts` files that
 * are already valid JavaScript can be loaded this way.
 */
export async function load(url, context, defaultLoad) {
  if (!url.endsWith(".ts")) {
    return defaultLoad(url, context, defaultLoad);
  }

  const source = await readFile(new URL(url), "utf8");
  return { format: "module", source, shortCircuit: true };
}