diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 6bf1320ccf..863771275c 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -2655,5 +2655,47 @@ export const webviewMessageHandler = async ( vscode.window.showWarningMessage(t("common:mdm.info.organization_requires_auth")) break } + case "fixMermaidDiagram": { + // Handle Mermaid diagram fixing request + const { code, error } = message + if (code && error) { + try { + // Import the MermaidDiagramFixer + const { MermaidDiagramFixer } = await import("../../services/mermaid/MermaidDiagramFixer") + + // Get the API configuration + const { apiConfiguration } = await provider.getState() + + // Create fixer instance with Gemini API key + const fixer = new MermaidDiagramFixer({ + geminiApiKey: apiConfiguration.geminiApiKey, + geminiModel: apiConfiguration.apiModelId, + }) + + // Fix the diagram + const fixedCode = await fixer.fixDiagram(code, error) + + // Send the fixed code back to the webview + await provider.postMessageToWebview({ + type: "mermaidDiagramFixed", + originalCode: code, + fixedCode: fixedCode, + }) + + vscode.window.showInformationMessage( + t("common:info.mermaid_diagram_fixed") || "Mermaid diagram has been fixed!", + ) + } catch (error) { + provider.log( + `Error fixing Mermaid diagram: ${JSON.stringify(error, Object.getOwnPropertyNames(error), 2)}`, + ) + vscode.window.showErrorMessage( + t("common:errors.mermaid.fix_failed") || + `Failed to fix diagram: ${error instanceof Error ? error.message : String(error)}`, + ) + } + } + break + } } } diff --git a/src/i18n/locales/en/common.json b/src/i18n/locales/en/common.json index e413bc0890..26755af7f6 100644 --- a/src/i18n/locales/en/common.json +++ b/src/i18n/locales/en/common.json @@ -90,7 +90,11 @@ "gemini": { "generate_stream": "Gemini generate context stream error: {{error}}", "generate_complete_prompt": "Gemini completion error: {{error}}", - "sources": "Sources:" + "sources": "Sources:", + "api_key_required": "Gemini API key is required for diagram fixing" + }, + "mermaid": { + "fix_failed": "Failed to fix Mermaid diagram" }, "cerebras": { "authenticationFailed": "Cerebras API authentication failed. Please check your API key is valid and not expired.", @@ -124,7 +128,8 @@ "image_copied_to_clipboard": "Image data URI copied to clipboard", "image_saved": "Image saved to {{path}}", "mode_exported": "Mode '{{mode}}' exported successfully", - "mode_imported": "Mode imported successfully" + "mode_imported": "Mode imported successfully", + "mermaid_diagram_fixed": "Mermaid diagram has been fixed successfully!" }, "answers": { "yes": "Yes", diff --git a/src/services/mermaid/MermaidDiagramFixer.ts b/src/services/mermaid/MermaidDiagramFixer.ts new file mode 100644 index 0000000000..bad976931f --- /dev/null +++ b/src/services/mermaid/MermaidDiagramFixer.ts @@ -0,0 +1,338 @@ +import { GoogleGenAI } from "@google/genai" +import { safeJsonParse } from "../../shared/safeJsonParse" +import { t } from "../../i18n" +import * as vscode from "vscode" + +// JSON Schema for structured Mermaid diagram representation +const MERMAID_JSON_SCHEMA = { + type: "object", + properties: { + diagramType: { + type: "string", + enum: ["flowchart", "sequence", "class", "state", "er", "gantt", "pie", "journey", "gitGraph", "mindmap"], + }, + title: { type: "string" }, + nodes: { + type: "array", + items: { + type: "object", + properties: { + id: { type: "string" }, + label: { type: "string" }, + shape: { type: "string" }, + style: { type: "string" }, + }, + required: ["id", "label"], + }, + }, + edges: { + type: "array", + items: { + type: "object", + properties: { + from: { type: "string" }, + to: { type: "string" }, + label: { type: "string" }, + type: { type: "string" }, + }, + required: ["from", "to"], + }, + }, + participants: { + type: "array", + items: { type: "string" }, + }, + messages: { + type: "array", + items: { + type: "object", + properties: { + from: { type: "string" }, + to: { type: "string" }, + message: { type: "string" }, + type: { type: "string" }, + }, + }, + }, + classes: { + type: "array", + items: { + type: "object", + properties: { + name: { type: "string" }, + attributes: { type: "array", items: { type: "string" } }, + methods: { type: "array", items: { type: "string" } }, + }, + }, + }, + relationships: { + type: "array", + items: { + type: "object", + properties: { + from: { type: "string" }, + to: { type: "string" }, + type: { type: "string" }, + label: { type: "string" }, + }, + }, + }, + }, + required: ["diagramType"], +} + +export interface MermaidFixerOptions { + geminiApiKey?: string + geminiModel?: string +} + +export class MermaidDiagramFixer { + private client: GoogleGenAI | null = null + private modelName: string + + constructor(private options: MermaidFixerOptions = {}) { + this.modelName = options.geminiModel || "gemini-2.0-flash-exp" + if (options.geminiApiKey) { + this.client = new GoogleGenAI({ apiKey: options.geminiApiKey }) + } + } + + /** + * Fix a Mermaid diagram with syntax errors using Gemini AI + * @param invalidCode The invalid Mermaid code + * @param errorMessage The error message from Mermaid parser + * @returns Fixed Mermaid code + * @throws Error if API key is missing or fixing fails + */ + async fixDiagram(invalidCode: string, errorMessage: string): Promise { + if (!this.client || !this.options.geminiApiKey || this.options.geminiApiKey.trim() === "") { + throw new Error("Gemini API key is required for diagram fixing") + } + + try { + // Stage 1: Generate structured JSON representation + const structuredJson = await this.generateStructuredJson(invalidCode, errorMessage) + if (!structuredJson) { + throw new Error("Failed to fix Mermaid diagram") + } + + // Stage 2: Validate JSON against schema + const validationResult = this.validateJson(structuredJson) + if (!validationResult.valid) { + console.error("JSON validation failed:", validationResult.errors) + throw new Error("Failed to fix Mermaid diagram") + } + + // Stage 3: Generate Python code to convert JSON to Mermaid + const pythonCode = await this.generatePythonConverter(structuredJson) + if (!pythonCode) { + throw new Error("Failed to fix Mermaid diagram") + } + + // Stage 4: Execute Python code to get final Mermaid DSL + const fixedMermaid = await this.executePythonCode(pythonCode, structuredJson) + if (!fixedMermaid) { + throw new Error("Failed to fix Mermaid diagram") + } + + // Post-process and validate the result + const cleanedMermaid = this.postProcessMermaid(fixedMermaid) + return cleanedMermaid + } catch (error) { + console.error("Error fixing Mermaid diagram:", error) + if (error instanceof Error && error.message === "Gemini API key is required for diagram fixing") { + throw error + } + throw new Error("Failed to fix Mermaid diagram") + } + } + + private async generateStructuredJson(invalidCode: string, errorMessage: string): Promise { + const prompt = ` +You are a Mermaid diagram expert. Analyze this invalid Mermaid code and its error message, then generate a corrected version as a structured JSON object. + +Invalid Mermaid Code: +\`\`\`mermaid +${invalidCode} +\`\`\` + +Error Message: +${errorMessage} + +Generate a JSON object that represents the corrected diagram structure. The JSON should follow this schema: +${JSON.stringify(MERMAID_JSON_SCHEMA, null, 2)} + +Rules: +1. NO parentheses in any labels or text +2. Use only alphanumeric characters, spaces, and basic punctuation in labels +3. Ensure all node IDs are unique and valid +4. Fix any syntax errors while preserving the original intent +5. Return ONLY valid JSON, no markdown or explanations + +Response:` + + try { + const model = this.client!.models.generateContent({ + model: this.modelName, + contents: [{ role: "user", parts: [{ text: prompt }] }], + config: { + temperature: 0.1, + maxOutputTokens: 2048, + }, + }) + + const result = await model + const text = result.text || "" + + // Try to parse as JSON + const json = safeJsonParse(text, null) + if (!json) { + console.error("Failed to parse JSON from Gemini response:", text) + return null + } + + return json + } catch (error) { + console.error("Error generating structured JSON:", error) + return null + } + } + + private validateJson(json: any): { valid: boolean; errors?: string[] } { + // Basic validation - check required fields and structure + const errors: string[] = [] + + if (!json.diagramType) { + errors.push("Missing required field: diagramType") + } + + // Check for parentheses in labels + const checkForParentheses = (obj: any, path: string = "") => { + if (typeof obj === "string" && (obj.includes("(") || obj.includes(")"))) { + errors.push(`Parentheses found in ${path}: "${obj}"`) + } else if (Array.isArray(obj)) { + obj.forEach((item, index) => checkForParentheses(item, `${path}[${index}]`)) + } else if (obj && typeof obj === "object") { + Object.entries(obj).forEach(([key, value]) => { + checkForParentheses(value, path ? `${path}.${key}` : key) + }) + } + } + + checkForParentheses(json) + + return { + valid: errors.length === 0, + errors: errors.length > 0 ? errors : undefined, + } + } + + private async generatePythonConverter(json: any): Promise { + const prompt = ` +Generate a Python function that converts this JSON structure into valid Mermaid DSL syntax. + +JSON Structure: +${JSON.stringify(json, null, 2)} + +Requirements: +1. The function should be named 'json_to_mermaid' +2. It should take the JSON object as input +3. It should return a string containing valid Mermaid DSL +4. Handle the specific diagram type (${json.diagramType}) +5. Ensure proper Mermaid syntax for the diagram type +6. NO parentheses in any output +7. Use proper escaping for special characters + +Return ONLY the Python code, no explanations or markdown.` + + try { + const model = this.client!.models.generateContent({ + model: this.modelName, + contents: [{ role: "user", parts: [{ text: prompt }] }], + config: { + temperature: 0.1, + maxOutputTokens: 2048, + }, + }) + + const result = await model + const code = result.text || "" + + // Clean up the code if it has markdown + const cleanCode = code + .replace(/```python\n?/g, "") + .replace(/```\n?/g, "") + .trim() + + return cleanCode + } catch (error) { + console.error("Error generating Python converter:", error) + return null + } + } + + private async executePythonCode(pythonCode: string, json: any): Promise { + // Since we can't actually execute Python in the browser/extension context, + // we'll use Gemini's code execution capability + const prompt = ` +Execute this Python code with the provided JSON input and return the output: + +Python Code: +\`\`\`python +${pythonCode} + +# Execute the function +import json +json_data = ${JSON.stringify(json)} +result = json_to_mermaid(json_data) +print(result) +\`\`\` + +Return ONLY the Mermaid DSL output, no explanations or markdown.` + + try { + const model = this.client!.models.generateContent({ + model: this.modelName, + contents: [{ role: "user", parts: [{ text: prompt }] }], + config: { + temperature: 0, + maxOutputTokens: 2048, + }, + }) + + const result = await model + const mermaidCode = result.text || "" + + // Clean up the output + const cleanMermaid = mermaidCode + .replace(/```mermaid\n?/g, "") + .replace(/```\n?/g, "") + .trim() + + return cleanMermaid + } catch (error) { + console.error("Error executing Python code:", error) + return null + } + } + + private postProcessMermaid(mermaidCode: string): string { + // Remove any remaining parentheses + let cleaned = mermaidCode.replace(/[()]/g, "") + + // Ensure proper line endings + cleaned = cleaned.replace(/\r\n/g, "\n") + + // Remove any duplicate whitespace + cleaned = cleaned.replace(/ +/g, " ") + + // Trim each line + cleaned = cleaned + .split("\n") + .map((line) => line.trim()) + .filter((line) => line.length > 0) + .join("\n") + + return cleaned + } +} diff --git a/src/services/mermaid/__tests__/MermaidDiagramFixer.spec.ts b/src/services/mermaid/__tests__/MermaidDiagramFixer.spec.ts new file mode 100644 index 0000000000..d136df0536 --- /dev/null +++ b/src/services/mermaid/__tests__/MermaidDiagramFixer.spec.ts @@ -0,0 +1,204 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { MermaidDiagramFixer } from "../MermaidDiagramFixer" + +// Mock the entire Google Generative AI module +vi.mock("@google/generative-ai", () => { + const mockGenerateContent = vi.fn() + return { + GoogleGenerativeAI: vi.fn().mockImplementation(() => ({ + getGenerativeModel: vi.fn().mockReturnValue({ + generateContent: mockGenerateContent, + }), + })), + _mockGenerateContent: mockGenerateContent, + } +}) + +describe("MermaidDiagramFixer", () => { + let fixer: MermaidDiagramFixer + let mockApiConfig: any + + beforeEach(() => { + vi.clearAllMocks() + + mockApiConfig = { + geminiApiKey: "test-api-key", + geminiModel: "gemini-1.5-flash", + } + + fixer = new MermaidDiagramFixer(mockApiConfig) + }) + + describe("fixDiagram", () => { + it("should throw error if API key is not configured", async () => { + const invalidConfig = { geminiModel: "gemini-1.5-flash" } + const fixer = new MermaidDiagramFixer(invalidConfig) + + await expect(fixer.fixDiagram("invalid diagram", "error message")).rejects.toThrow( + "Gemini API key is required for diagram fixing", + ) + }) + + it("should throw error if API key is empty string", async () => { + const invalidConfig = { geminiApiKey: "", geminiModel: "gemini-1.5-flash" } + const fixer = new MermaidDiagramFixer(invalidConfig) + + await expect(fixer.fixDiagram("invalid diagram", "error message")).rejects.toThrow( + "Gemini API key is required for diagram fixing", + ) + }) + + it("should use default model if not specified", () => { + const configWithoutModel = { geminiApiKey: "test-key" } + const fixer = new MermaidDiagramFixer(configWithoutModel) + + // The constructor should set a default model + expect(fixer).toBeDefined() + }) + + it("should handle API errors gracefully", async () => { + // Get the mock function + const { _mockGenerateContent } = (await import("@google/generative-ai")) as any + _mockGenerateContent.mockRejectedValueOnce(new Error("API Error")) + + await expect(fixer.fixDiagram("invalid diagram", "error")).rejects.toThrow("Failed to fix Mermaid diagram") + }) + + it("should successfully process a valid response", async () => { + // Get the mock function + const { _mockGenerateContent } = (await import("@google/generative-ai")) as any + + // Mock successful responses for each stage + _mockGenerateContent + .mockResolvedValueOnce({ + response: { + text: () => + JSON.stringify({ + diagram_type: "flowchart", + nodes: [ + { id: "A", label: "Start" }, + { id: "B", label: "End" }, + ], + edges: [{ from: "A", to: "B", label: "Process" }], + }), + }, + }) + .mockResolvedValueOnce({ + response: { + text: () => "Valid JSON structure", + }, + }) + .mockResolvedValueOnce({ + response: { + text: () => ` +def convert_to_mermaid(json_data): + result = "graph TD\\n" + for edge in json_data['edges']: + label = f"|{edge['label']}|" if edge.get('label') else "" + result += f"{edge['from']} -->{label} {edge['to']}\\n" + return result.strip() +`, + }, + }) + .mockResolvedValueOnce({ + response: { + text: () => "graph TD\nA -->|Process| B", + }, + }) + + const result = await fixer.fixDiagram("graph TD\nA -> B", "Syntax error") + + expect(result).toBe("graph TD\nA -->|Process| B") + expect(_mockGenerateContent).toHaveBeenCalledTimes(4) + }) + + it("should handle empty response from API", async () => { + const { _mockGenerateContent } = (await import("@google/generative-ai")) as any + _mockGenerateContent.mockResolvedValueOnce({ + response: { + text: () => "", + }, + }) + + await expect(fixer.fixDiagram("invalid", "error")).rejects.toThrow("Failed to fix Mermaid diagram") + }) + + it("should handle invalid JSON response", async () => { + const { _mockGenerateContent } = (await import("@google/generative-ai")) as any + _mockGenerateContent.mockResolvedValueOnce({ + response: { + text: () => "Not valid JSON", + }, + }) + + await expect(fixer.fixDiagram("invalid", "error")).rejects.toThrow("Failed to fix Mermaid diagram") + }) + + it("should handle sequence diagrams", async () => { + const { _mockGenerateContent } = (await import("@google/generative-ai")) as any + + _mockGenerateContent + .mockResolvedValueOnce({ + response: { + text: () => + JSON.stringify({ + diagram_type: "sequence", + participants: ["Alice", "Bob"], + messages: [{ from: "Alice", to: "Bob", message: "Hello", type: "solid" }], + }), + }, + }) + .mockResolvedValueOnce({ + response: { + text: () => "Valid", + }, + }) + .mockResolvedValueOnce({ + response: { + text: () => ` +def convert_to_mermaid(json_data): + result = "sequenceDiagram\\n" + for p in json_data['participants']: + result += f"participant {p}\\n" + for msg in json_data['messages']: + arrow = '->' if msg['type'] == 'solid' else '-->' + result += f"{msg['from']}{arrow}{msg['to']}: {msg['message']}\\n" + return result.strip() +`, + }, + }) + .mockResolvedValueOnce({ + response: { + text: () => "sequenceDiagram\nparticipant Alice\nparticipant Bob\nAlice->Bob: Hello", + }, + }) + + const result = await fixer.fixDiagram("sequenceDiagram\nAlice->Bob Hello", "Missing colon") + + expect(result).toBe("sequenceDiagram\nparticipant Alice\nparticipant Bob\nAlice->Bob: Hello") + }) + }) + + describe("error handling", () => { + it("should handle network errors", async () => { + const { _mockGenerateContent } = (await import("@google/generative-ai")) as any + _mockGenerateContent.mockRejectedValueOnce(new Error("Network error")) + + await expect(fixer.fixDiagram("invalid", "error")).rejects.toThrow("Failed to fix Mermaid diagram") + }) + + it("should handle timeout errors", async () => { + const { _mockGenerateContent } = (await import("@google/generative-ai")) as any + _mockGenerateContent.mockRejectedValueOnce(new Error("Request timeout")) + + await expect(fixer.fixDiagram("invalid", "error")).rejects.toThrow("Failed to fix Mermaid diagram") + }) + + it("should handle rate limit errors", async () => { + const { _mockGenerateContent } = (await import("@google/generative-ai")) as any + _mockGenerateContent.mockRejectedValueOnce(new Error("Rate limit exceeded")) + + await expect(fixer.fixDiagram("invalid", "error")).rejects.toThrow("Failed to fix Mermaid diagram") + }) + }) +}) diff --git a/src/services/mermaid/__tests__/google-generative-ai.d.ts b/src/services/mermaid/__tests__/google-generative-ai.d.ts new file mode 100644 index 0000000000..466f8fe597 --- /dev/null +++ b/src/services/mermaid/__tests__/google-generative-ai.d.ts @@ -0,0 +1,5 @@ +// Mock type declarations for @google/generative-ai in tests +declare module "@google/generative-ai" { + export const GoogleGenerativeAI: any + export const _mockGenerateContent: any +} diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 65fe181859..126e8c7c81 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -122,6 +122,7 @@ export interface ExtensionMessage { | "showEditMessageDialog" | "commands" | "insertTextIntoTextarea" + | "mermaidDiagramFixed" text?: string payload?: any // Add a generic payload for now, can refine later action?: @@ -196,6 +197,8 @@ export interface ExtensionMessage { messageTs?: number context?: string commands?: Command[] + originalCode?: string + fixedCode?: string } export type ExtensionState = Pick< diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index e2df805340..dc5e8d3268 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -212,8 +212,11 @@ export interface WebviewMessage { | "createCommand" | "insertTextIntoTextarea" | "showMdmAuthRequiredNotification" + | "fixMermaidDiagram" text?: string editedMessageContent?: string + code?: string + error?: string tab?: "settings" | "history" | "mcp" | "modes" | "chat" | "marketplace" | "account" disabled?: boolean context?: string diff --git a/webview-ui/src/components/common/MermaidBlock.tsx b/webview-ui/src/components/common/MermaidBlock.tsx index 95c795fdc5..d7736be598 100644 --- a/webview-ui/src/components/common/MermaidBlock.tsx +++ b/webview-ui/src/components/common/MermaidBlock.tsx @@ -87,11 +87,13 @@ interface MermaidBlockProps { code: string } -export default function MermaidBlock({ code }: MermaidBlockProps) { +export default function MermaidBlock({ code: initialCode }: MermaidBlockProps) { const containerRef = useRef(null) + const [code, setCode] = useState(initialCode) const [isLoading, setIsLoading] = useState(false) const [error, setError] = useState(null) const [isErrorExpanded, setIsErrorExpanded] = useState(false) + const [isFixing, setIsFixing] = useState(false) const { showCopyFeedback, copyWithFeedback } = useCopyToClipboard() const { t } = useAppTranslation() @@ -101,6 +103,23 @@ export default function MermaidBlock({ code }: MermaidBlockProps) { setError(null) }, [code]) + // Listen for fixed diagram response from extension + useEffect(() => { + const handleMessage = (event: MessageEvent) => { + const message = event.data + if (message.type === "mermaidDiagramFixed" && message.originalCode === initialCode) { + // Update the code with the fixed version + setCode(message.fixedCode) + setIsFixing(false) + setError(null) + setIsErrorExpanded(false) + } + } + + window.addEventListener("message", handleMessage) + return () => window.removeEventListener("message", handleMessage) + }, [initialCode]) + // 2) Debounce the actual parse/render useDebounceEffect( () => { @@ -153,6 +172,22 @@ export default function MermaidBlock({ code }: MermaidBlockProps) { // Copy functionality handled directly through the copyWithFeedback utility + const handleFixDiagram = async () => { + setIsFixing(true) + try { + // Send message to extension to fix the diagram + vscode.postMessage({ + type: "fixMermaidDiagram", + code: code, + error: error || undefined, + }) + } catch (err) { + console.error("Error fixing diagram:", err) + } finally { + setIsFixing(false) + } + } + return ( {isLoading && {t("common:mermaid.loading")}} @@ -188,7 +223,16 @@ export default function MermaidBlock({ code }: MermaidBlockProps) { }}> {t("common:mermaid.render_error")} -
+
+ { + e.stopPropagation() + handleFixDiagram() + }} + disabled={isFixing}> + + {t("common:mermaid.buttons.fix")} + { e.stopPropagation() @@ -309,6 +353,39 @@ const CopyButton = styled.button` } ` +const FixButton = styled.button<{ disabled?: boolean }>` + padding: 4px 8px; + height: 26px; + margin-right: 4px; + color: var(--vscode-button-foreground); + background: var(--vscode-button-background); + display: flex; + align-items: center; + justify-content: center; + border: none; + border-radius: 2px; + cursor: ${(props) => (props.disabled ? "not-allowed" : "pointer")}; + opacity: ${(props) => (props.disabled ? 0.5 : 1)}; + font-size: 12px; + + &:hover:not(:disabled) { + background: var(--vscode-button-hoverBackground); + } + + .codicon-loading { + animation: spin 1s linear infinite; + } + + @keyframes spin { + from { + transform: rotate(0deg); + } + to { + transform: rotate(360deg); + } + } +` + interface SvgContainerProps { $isLoading: boolean } diff --git a/webview-ui/src/i18n/locales/en/common.json b/webview-ui/src/i18n/locales/en/common.json index 973cb48297..e351771629 100644 --- a/webview-ui/src/i18n/locales/en/common.json +++ b/webview-ui/src/i18n/locales/en/common.json @@ -34,7 +34,8 @@ "save": "Save Image", "viewCode": "View Code", "viewDiagram": "View Diagram", - "close": "Close" + "close": "Close", + "fix": "Fix Diagram" }, "modal": { "codeTitle": "Mermaid Code"