diff --git a/src/commands/aicommits.ts b/src/commands/aicommits.ts index e26372ac..739509a9 100644 --- a/src/commands/aicommits.ts +++ b/src/commands/aicommits.ts @@ -23,7 +23,7 @@ import { } from '../utils/openai.js'; import { KnownError, handleCommandError } from '../utils/error.js'; -import { getCommitMessage } from '../utils/commit-helpers.js'; +import { getCommitMessage, type CommitMessageResult } from '../utils/commit-helpers.js'; export default async ( generate: number | undefined, @@ -108,141 +108,175 @@ export default async ( // Check if diff is large and needs chunking const MAX_FILES = 50; const CHUNK_SIZE = 10; - let isChunking = false; - if (staged.files.length > MAX_FILES) { - isChunking = true; + const isChunking = staged.files.length > MAX_FILES; + + const baseUrl = providerInstance.getBaseUrl(); + const apiKey = providerInstance.getApiKey() || ''; + + // Truncate diff if too large to avoid context limits + const maxDiffLength = 30000; // Approximate 7.5k tokens + let diffToUse = staged.diff; + if (diffToUse.length > maxDiffLength) { + diffToUse = + diffToUse.substring(diffToUse.length - maxDiffLength) + + '\n\n[Diff truncated due to size]'; } - const s = spinner(); - s.start( - `🔍 Analyzing changes in ${staged.files.length} file${ - staged.files.length === 1 ? '' : 's' - }` - ); - const startTime = Date.now(); - let messages: string[]; - let usage: any; - try { - const baseUrl = providerInstance.getBaseUrl(); - const apiKey = providerInstance.getApiKey() || ''; - - if (isChunking) { - // Split files into chunks - const chunks: string[][] = []; - for (let i = 0; i < staged.files.length; i += CHUNK_SIZE) { - chunks.push(staged.files.slice(i, i + CHUNK_SIZE)); - } + // Helper function to generate messages (supports regeneration) + const generateMessages = async (regenerateOptions?: { + previousMessage: string; + userContext?: string; + }) => { + const s = spinner(); + const actionText = regenerateOptions ? '🔄 Regenerating' : '🔍 Analyzing'; + s.start( + `${actionText} changes in ${staged.files.length} file${ + staged.files.length === 1 ? '' : 's' + }` + ); + const startTime = Date.now(); + let messages: string[]; + let usage: any; - const chunkMessages: string[] = []; - let totalUsage = { - promptTokens: 0, - completionTokens: 0, - totalTokens: 0, - }; - - for (const chunk of chunks) { - const chunkDiff = await getStagedDiffForFiles(chunk, excludeFiles); - if (chunkDiff && chunkDiff.diff) { - // Truncate diff if too large to avoid context limits - const maxDiffLength = 30000; // Approximate 7.5k tokens - let diffToUse = chunkDiff.diff; - if (diffToUse.length > maxDiffLength) { - diffToUse = - diffToUse.substring(diffToUse.length - maxDiffLength) + - '\n\n[Diff truncated due to size]'; - } - const result = await generateCommitMessage( - baseUrl, - apiKey, - config.model!, - config.locale, - diffToUse, - config.generate, - config['max-length'], - config.type, - timeout - ); - chunkMessages.push(...result.messages); - if (result.usage) { - totalUsage.promptTokens += - (result.usage as any).promptTokens || 0; - totalUsage.completionTokens += - (result.usage as any).completionTokens || 0; - totalUsage.totalTokens += (result.usage as any).totalTokens || 0; + try { + if (isChunking) { + // Split files into chunks + const chunks: string[][] = []; + for (let i = 0; i < staged.files.length; i += CHUNK_SIZE) { + chunks.push(staged.files.slice(i, i + CHUNK_SIZE)); + } + + const chunkMessages: string[] = []; + let totalUsage = { + promptTokens: 0, + completionTokens: 0, + totalTokens: 0, + }; + + for (const chunk of chunks) { + const chunkDiff = await getStagedDiffForFiles(chunk, excludeFiles); + if (chunkDiff && chunkDiff.diff) { + let chunkDiffToUse = chunkDiff.diff; + if (chunkDiffToUse.length > maxDiffLength) { + chunkDiffToUse = + chunkDiffToUse.substring(chunkDiffToUse.length - maxDiffLength) + + '\n\n[Diff truncated due to size]'; + } + const result = await generateCommitMessage( + baseUrl, + apiKey, + config.model!, + config.locale, + chunkDiffToUse, + config.generate, + config['max-length'], + config.type, + timeout, + regenerateOptions + ); + chunkMessages.push(...result.messages); + if (result.usage) { + totalUsage.promptTokens += + (result.usage as any).promptTokens || 0; + totalUsage.completionTokens += + (result.usage as any).completionTokens || 0; + totalUsage.totalTokens += (result.usage as any).totalTokens || 0; + } } } - } - // Combine the chunk messages - const combineResult = await combineCommitMessages( - chunkMessages, - baseUrl, - apiKey, - config.model!, - config.locale, - config['max-length'], - config.type, - timeout - ); - messages = combineResult.messages; - if (combineResult.usage) { - totalUsage.promptTokens += - (combineResult.usage as any).promptTokens || 0; - totalUsage.completionTokens += - (combineResult.usage as any).completionTokens || 0; - totalUsage.totalTokens += - (combineResult.usage as any).totalTokens || 0; + // Combine the chunk messages + const combineResult = await combineCommitMessages( + chunkMessages, + baseUrl, + apiKey, + config.model!, + config.locale, + config['max-length'], + config.type, + timeout + ); + messages = combineResult.messages; + if (combineResult.usage) { + totalUsage.promptTokens += + (combineResult.usage as any).promptTokens || 0; + totalUsage.completionTokens += + (combineResult.usage as any).completionTokens || 0; + totalUsage.totalTokens += + (combineResult.usage as any).totalTokens || 0; + } + usage = totalUsage; + } else { + const result = await generateCommitMessage( + baseUrl, + apiKey, + config.model!, + config.locale, + diffToUse, + config.generate, + config['max-length'], + config.type, + timeout, + regenerateOptions + ); + messages = result.messages; + usage = result.usage; } - usage = totalUsage; - } else { - // Truncate diff if too large to avoid context limits - const maxDiffLength = 30000; // Approximate 7.5k tokens - let diffToUse = staged.diff; - if (diffToUse.length > maxDiffLength) { - diffToUse = - diffToUse.substring(diffToUse.length - maxDiffLength) + - '\n\n[Diff truncated due to size]'; + + return { messages, usage }; + } finally { + const duration = Date.now() - startTime; + let tokensStr = ''; + if (usage?.total_tokens) { + const tokens = usage.total_tokens; + const formattedTokens = + tokens >= 1000 ? `${(tokens / 1000).toFixed(0)}k` : tokens.toString(); + const speed = Math.round(tokens / (duration / 1000)); + tokensStr = `, ${formattedTokens} tokens (${speed} tokens/s)`; } - const result = await generateCommitMessage( - baseUrl, - apiKey, - config.model!, - config.locale, - diffToUse, - config.generate, - config['max-length'], - config.type, - timeout + const doneText = regenerateOptions ? '✅ Regenerated' : '✅ Changes analyzed'; + s.stop( + `${doneText} in ${(duration / 1000).toFixed(1)}s${tokensStr}` ); - messages = result.messages; - usage = result.usage; - } - } finally { - const duration = Date.now() - startTime; - let tokensStr = ''; - if (usage?.total_tokens) { - const tokens = usage.total_tokens; - const formattedTokens = - tokens >= 1000 ? `${(tokens / 1000).toFixed(0)}k` : tokens.toString(); - const speed = Math.round(tokens / (duration / 1000)); - tokensStr = `, ${formattedTokens} tokens (${speed} tokens/s)`; } - s.stop( - `✅ Changes analyzed in ${(duration / 1000).toFixed(1)}s${tokensStr}` - ); - } + }; + + // Initial generation + let { messages } = await generateMessages(); if (messages.length === 0) { throw new KnownError('No commit messages were generated. Try again.'); } - // Get the commit message - const message = await getCommitMessage(messages, skipConfirm); - if (!message) { - outro('Commit cancelled'); - return; + // Message selection loop (supports regeneration) + let result: CommitMessageResult; + while (true) { + result = await getCommitMessage(messages, skipConfirm); + + if (result.action === 'cancel') { + outro('Commit cancelled'); + return; + } + + if (result.action === 'confirm') { + break; + } + + // Regenerate + const previousMessage = messages[0]; // Use first message as reference + const regenerated = await generateMessages({ + previousMessage, + userContext: result.context, + }); + messages = regenerated.messages; + + if (messages.length === 0) { + throw new KnownError('No commit messages were generated. Try again.'); + } } + const message = result.message; + // Handle clipboard mode (early return) if (copyToClipboard) { const success = await copyMessage(message); diff --git a/src/utils/commit-helpers.ts b/src/utils/commit-helpers.ts index a652c167..6a2783db 100644 --- a/src/utils/commit-helpers.ts +++ b/src/utils/commit-helpers.ts @@ -1,5 +1,10 @@ import { KnownError } from './error.js'; +export type CommitMessageResult = + | { action: 'confirm'; message: string } + | { action: 'cancel' } + | { action: 'regenerate'; context?: string }; + export const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); export const retry = async (fn: () => Promise, attempts: number = 3, delay: number = 1000): Promise => { @@ -17,8 +22,8 @@ export const retry = async (fn: () => Promise, attempts: number = 3, delay export const getCommitMessage = async ( messages: string[], skipConfirm: boolean -): Promise => { - const { select, confirm, isCancel } = await import('@clack/prompts'); +): Promise => { + const { select, text, isCancel } = await import('@clack/prompts'); const { dim } = await import('kolorist'); // Check if interactive prompts are available @@ -29,7 +34,7 @@ export const getCommitMessage = async ( const [message] = messages; if (skipConfirm) { - return message; + return { action: 'confirm', message }; } if (!isInteractive) { @@ -37,16 +42,41 @@ export const getCommitMessage = async ( } console.log(`\n\x1b[1m${message}\x1b[0m\n`); - const confirmed = await confirm({ - message: 'Use this commit message?', + const selected = await select({ + message: 'What would you like to do?', + options: [ + { label: 'Use this commit message', value: 'confirm' }, + { label: `Regenerate ${dim('(r)')}`, value: 'regenerate' }, + { label: 'Cancel', value: 'cancel' }, + ], }); - return confirmed && !isCancel(confirmed) ? message : null; + if (isCancel(selected) || selected === 'cancel') { + return { action: 'cancel' }; + } + + if (selected === 'regenerate') { + const context = await text({ + message: `Add context for regeneration ${dim('(optional, press Enter to skip)')}:`, + placeholder: 'e.g., "focus on the bug fix" or "be more specific"', + }); + + if (isCancel(context)) { + return { action: 'cancel' }; + } + + return { + action: 'regenerate', + context: context && typeof context === 'string' ? context.trim() || undefined : undefined, + }; + } + + return { action: 'confirm', message }; } // Multiple messages case if (skipConfirm) { - return messages[0]; + return { action: 'confirm', message: messages[0] }; } if (!isInteractive) { @@ -55,8 +85,33 @@ export const getCommitMessage = async ( const selected = await select({ message: `Pick a commit message to use: ${dim('(Ctrl+c to exit)')}`, - options: messages.map((value) => ({ label: value, value })), + options: [ + ...messages.map((value) => ({ label: value, value })), + { label: dim('─────────────────────'), value: 'separator', disabled: true } as any, + { label: `Regenerate all ${dim('(r)')}`, value: 'regenerate' }, + { label: 'Cancel', value: 'cancel' }, + ], }); - return isCancel(selected) ? null : (selected as string); + if (isCancel(selected) || selected === 'cancel') { + return { action: 'cancel' }; + } + + if (selected === 'regenerate') { + const context = await text({ + message: `Add context for regeneration ${dim('(optional, press Enter to skip)')}:`, + placeholder: 'e.g., "focus on the bug fix" or "be more specific"', + }); + + if (isCancel(context)) { + return { action: 'cancel' }; + } + + return { + action: 'regenerate', + context: context && typeof context === 'string' ? context.trim() || undefined : undefined, + }; + } + + return { action: 'confirm', message: selected as string }; }; \ No newline at end of file diff --git a/src/utils/openai.ts b/src/utils/openai.ts index 4e46f6d8..c31eaa83 100644 --- a/src/utils/openai.ts +++ b/src/utils/openai.ts @@ -77,7 +77,11 @@ export const generateCommitMessage = async ( completions: number, maxLength: number, type: CommitType, - timeout: number + timeout: number, + regenerateOptions?: { + previousMessage: string; + userContext?: string; + } ) => { if (process.env.DEBUG) { console.log('Diff being sent to AI:'); @@ -100,9 +104,9 @@ export const generateCommitMessage = async ( const promises = Array.from({ length: completions }, () => generateText({ model: provider(model), - system: generatePrompt(locale, maxLength, type), + system: generatePrompt(locale, maxLength, type, regenerateOptions), prompt: diff, - temperature: 0.4, + temperature: regenerateOptions ? 0.7 : 0.4, // Higher temperature for more variation when regenerating maxRetries: 2, maxOutputTokens: 2000, }).finally(() => clearTimeout(timeoutId)) diff --git a/src/utils/prompt.ts b/src/utils/prompt.ts index 77874eb6..484970c9 100644 --- a/src/utils/prompt.ts +++ b/src/utils/prompt.ts @@ -126,9 +126,13 @@ const commitTypes: Record = { export const generatePrompt = ( locale: string, maxLength: number, - type: CommitType -) => - [ + type: CommitType, + regenerateOptions?: { + previousMessage: string; + userContext?: string; + } +) => { + const basePrompt = [ 'Generate a concise git commit message title in present tense that precisely describes the key changes in the following code diff. Focus on what was changed, not just file names. Provide only the title, no description or body.', `Message language: ${locale}`, `Commit message must be a maximum of ${maxLength} characters.`, @@ -137,6 +141,19 @@ export const generatePrompt = ( 'Be specific: include concrete details (package names, versions, functionality) rather than generic statements.', commitTypes[type], specifyCommitFormat(type), - ] - .filter(Boolean) - .join('\n'); + ]; + + if (regenerateOptions) { + basePrompt.push( + '', + 'REGENERATION REQUEST:', + `The previous commit message was: "${regenerateOptions.previousMessage}"`, + 'Generate a meaningfully DIFFERENT commit message. Use different wording, emphasis, or focus while still accurately describing the changes.', + regenerateOptions.userContext + ? `User guidance: ${regenerateOptions.userContext}` + : '' + ); + } + + return basePrompt.filter(Boolean).join('\n'); +};