diff --git a/backend/server.py b/backend/server.py
index a84e64f..9743156 100644
--- a/backend/server.py
+++ b/backend/server.py
@@ -43,6 +43,7 @@
 TORCH_MLIR_OPT_PATH = os.environ.get("TORCH_MLIR_OPT_PATH", "")
 LLVM_BIN_PATH = os.environ.get("LLVM_BIN_PATH", "")
+TRITON_OPT_PATH = os.environ.get("TRITON_OPT_PATH", "")
 
 
 class CodeRequest(BaseModel):
@@ -55,6 +56,8 @@ class CodeRequest(BaseModel):
     mlir_translate: Optional[str] = ""
     llvm_opt: Optional[str] = ""
     llc: Optional[str] = ""
+    triton_opt: Optional[str] = ""
+    triton_llvm_opt: Optional[str] = ""
     user_tool: Optional[str] = ""
     dump_after_each_opt: Optional[bool] = False
 
@@ -156,6 +159,10 @@ def apply_optional_passes(
         tool_path = LLVM_BIN_PATH + "opt"
     elif tool == "llc":
         tool_path = LLVM_BIN_PATH + "llc"
+    elif tool == "triton-opt":
+        tool_path = TRITON_OPT_PATH + "triton-opt"
+    elif tool == "triton-llvm-opt":
+        tool_path = TRITON_OPT_PATH + "triton-llvm-opt"
     elif tool == "user-tool":
         tokens = flags.strip().split()
         if not tokens:
@@ -341,7 +348,9 @@ def generate_llvm_ir(
 
 
 # TODO: Figure out static compilation.
-def compile_triton_ir(code: str, ir_type: str) -> str:
+def compile_triton_ir(
+    code: str, ir_type: str, pipeline: List[Tuple[str, str]], dump_each: bool
+) -> str:
     try:
         code_hash = hash_code(code)
         cache_info = cached_triton_runs.get(code_hash)
@@ -403,7 +412,8 @@ def compile_triton_ir(code: str, ir_type: str) -> str:
 
         cached_triton_runs[code_hash]["active_users"] += 1
 
-        return "\n\n".join(output_parts)
+        ir_dump = "\n\n".join(output_parts)
+        return apply_optional_passes(ir_dump, pipeline, dump_each)
 
     except Exception as e:
         return f"Error compiling Triton IR: {str(e)}"
@@ -427,6 +437,12 @@ def build_pipeline(request: CodeRequest) -> List[Tuple[str, str]]:
     if request.llc:
         for stage in request.llc.split("&&"):
             pipeline.append(("llc", stage.strip()))
+    if request.triton_opt:
+        for stage in request.triton_opt.split("&&"):
+            pipeline.append(("triton-opt", stage.strip()))
+    if request.triton_llvm_opt:
+        for stage in request.triton_llvm_opt.split("&&"):
+            pipeline.append(("triton-llvm-opt", stage.strip()))
     if request.user_tool:
         for stage in request.user_tool.split("&&"):
             pipeline.append(("user-tool", stage.strip()))
@@ -437,7 +453,10 @@ def build_pipeline(request: CodeRequest) -> List[Tuple[str, str]]:
 def process_model(request: CodeRequest) -> str:
     try:
         if request.ir_type.startswith("triton"):
-            return compile_triton_ir(request.code, request.ir_type)
+            pipeline = build_pipeline(request)
+            return compile_triton_ir(
+                request.code, request.ir_type, pipeline, request.dump_after_each_opt
+            )
 
         if request.ir_type == "raw_ir" and request.selected_language == "pytorch":
             # Execute user Python, capture stdout.
diff --git a/src/app/page.js b/src/app/page.js
index 6b43532..8f91e25 100644
--- a/src/app/page.js
+++ b/src/app/page.js
@@ -84,6 +84,9 @@ const pytorchIROptions = [
 const tritonIROptions = [
   { value: "triton_ir", label: "Triton IR" },
   { value: "triton_gpu_ir", label: "Triton GPU IR" },
+  // FIXME: 'Triton' LLVM IR is basically LLVM IR. It is kept separate for
+  // now because all of the Triton IRs are just dumps from the cache dir
+  // after model execution.
   { value: "triton_llvm_ir", label: "LLVM IR" },
   { value: "triton_nvptx", label: "NVPTX" },
 ];
@@ -168,8 +171,10 @@ export default function PyTorchTritonExplorer() {
     if (flags !== null) {
       setIrWindows((prev) =>
         prev.map((w) =>
-          w.id === id ? { ...w, pipeline: [...w.pipeline, { tool, flags }] } : w
-        )
+          w.id === id
+            ? { ...w, pipeline: [...w.pipeline, { tool, flags }] }
+            : w,
+        ),
       );
     }
   };
@@ -185,8 +190,8 @@ export default function PyTorchTritonExplorer() {
                 { tool: currentTool, flags: currentFlags },
               ],
             }
-          : w
-      )
+          : w,
+      ),
     );
     setModalVisible(false);
   };
@@ -206,11 +211,11 @@ export default function PyTorchTritonExplorer() {
           ? {
               ...w,
               pipeline: w.pipeline.map((p, i) =>
-                i === editPassIndex ? { tool: editTool, flags: editFlags } : p
+                i === editPassIndex ? { tool: editTool, flags: editFlags } : p,
               ),
             }
-          : w
-      )
+          : w,
+      ),
     );
     setEditModalVisible(false);
   };
@@ -223,8 +228,8 @@ export default function PyTorchTritonExplorer() {
               ...w,
               pipeline: w.pipeline.filter((_, i) => i !== editPassIndex),
             }
-          : w
-      )
+          : w,
+      ),
     );
     setEditModalVisible(false);
   };
@@ -285,7 +290,7 @@ export default function PyTorchTritonExplorer() {
   const toggleCollapse = (id) => {
     setIrWindows((prev) =>
-      prev.map((w) => (w.id === id ? { ...w, collapsed: !w.collapsed } : w))
+      prev.map((w) => (w.id === id ? { ...w, collapsed: !w.collapsed } : w)),
     );
   };
 
@@ -293,14 +298,14 @@ export default function PyTorchTritonExplorer() {
     const value = event.target.value;
     setIrWindows((prev) =>
       prev.map((w) =>
-        w.id === id ? { ...w, selectedIR: value, pipeline: [] } : w
-      )
+        w.id === id ? { ...w, selectedIR: value, pipeline: [] } : w,
+      ),
     );
   };
 
   const generateIR = async (id) => {
     setIrWindows((prev) =>
-      prev.map((w) => (w.id === id ? { ...w, loading: true } : w))
+      prev.map((w) => (w.id === id ? { ...w, loading: true } : w)),
     );
 
     const irWin = irWindows.find((w) => w.id === id);
@@ -338,17 +343,20 @@ export default function PyTorchTritonExplorer() {
       dump_after_each_opt: irWin.dumpAfterEachOpt,
     };
 
-    const response = await fetch("http://" + window.location.hostname + ":8000/generate_ir", {
-      method: "POST",
-      headers: { "Content-Type": "application/json" },
-      body: JSON.stringify(body),
-    });
+    const response = await fetch(
+      "http://" + window.location.hostname + ":8000/generate_ir",
+      {
+        method: "POST",
+        headers: { "Content-Type": "application/json" },
+        body: JSON.stringify(body),
+      },
+    );
     const data = await response.json();
 
     setIrWindows((prev) =>
       prev.map((w) =>
-        w.id === id ? { ...w, output: data.output, loading: false } : w
-      )
+        w.id === id ? { ...w, output: data.output, loading: false } : w,
+      ),
     );
   };
 
@@ -575,190 +583,69 @@ export default function PyTorchTritonExplorer() {
             ))}
 
-            {(selectedLanguage === "pytorch" ||
-              selectedLanguage == "raw_ir") &&
-              (() => {
-                const allowTorchMlirOpt = ["torch_script_graph_ir", "torch_mlir", "raw_ir"].includes(
-                  irWin.selectedIR
-                );
-                const allowMlirOpt = [
-                  "torch_script_graph_ir",
-                  "torch_mlir",
-                  "tosa_mlir",
-                  "linalg_on_tensors_mlir",
-                  "stablehlo_mlir",
-                  "llvm_mlir",
-                  "raw_ir",
-                ].includes(irWin.selectedIR);
-                const allowMlirTranslate = [
-                  "torch_script_graph_ir",
-                  "torch_mlir",
-                  "tosa_mlir",
-                  "linalg_on_tensors_mlir",
-                  "stablehlo_mlir",
-                  "llvm_mlir",
-                  "raw_ir",
-                ].includes(irWin.selectedIR);
-                const allowLlvmOpt = true;
-                const allowLLC = true;
-                const allowUserTool = true;
-
-                if (
-                  !allowTorchMlirOpt &&
-                  !allowMlirOpt &&
-                  !allowMlirTranslate &&
-                  !allowLlvmOpt &&
-                  !allowUserTool
-                )
-                  return null;
-
-                return (
-                    {allowTorchMlirOpt && (
-
-                    )}
-                    {allowMlirOpt && (
-
-                    )}
-                    {allowMlirTranslate && (
-
-                    )}
-                    {allowLlvmOpt && (
-
-                    )}
-                    {allowLLC && (
-
-                    )}
-                    {allowUserTool && (
-
-                    )}
+            {(() => {
+              const allowTorchMlirOpt = [
+                "torch_script_graph_ir",
+                "torch_mlir",
+                "raw_ir",
+              ].includes(irWin.selectedIR);
+              const allowMlirOpt = [
+                "torch_script_graph_ir",
+                "torch_mlir",
+                "tosa_mlir",
+                "linalg_on_tensors_mlir",
+                "stablehlo_mlir",
+                "llvm_mlir",
+                "raw_ir",
+              ].includes(irWin.selectedIR);
+              const allowMlirTranslate = [
+                "torch_script_graph_ir",
+                "torch_mlir",
+                "tosa_mlir",
+                "linalg_on_tensors_mlir",
+                "stablehlo_mlir",
+                "llvm_mlir",
+                "raw_ir",
+              ].includes(irWin.selectedIR);
+              const allowTritonOpt = [
+                "triton_ir",
+                "triton_gpu_ir",
+              ].includes(irWin.selectedIR);
+              const allowTritonLLVMOpt = [
+                "triton_ir",
+                "triton_gpu_ir",
+                "triton_llvm_ir",
+              ].includes(irWin.selectedIR);
+              const allowLlvmOpt = irWin.selectedIR !== "triton_nvptx";
+              const allowLLC = irWin.selectedIR !== "triton_nvptx";
+              const allowUserTool = true;
+
+              if (
+                !allowTorchMlirOpt &&
+                !allowMlirOpt &&
+                !allowMlirTranslate &&
+                !allowTritonOpt &&
+                !allowTritonLLVMOpt &&
+                !allowLlvmOpt &&
+                !allowUserTool
+              )
+                return null;
+
+              return (
+ {allowTorchMlirOpt && ( -
- ); - })()} - - {(selectedLanguage === "pytorch" || - selectedLanguage == "raw_ir") && - irWin.pipeline.length > 0 && ( + )} + {allowMlirOpt && ( + + )} + {allowMlirTranslate && ( + + )} + {allowTritonOpt && ( + + )} + {allowTritonLLVMOpt && ( + + )} + {allowLlvmOpt && ( + + )} + {allowLLC && ( + + )} + {allowUserTool && ( + + )} + +
+ ); + })()} + + {irWin.pipeline.length > 0 && ( +
-
- - Compilation pipeline: - + Compilation pipeline: + - - {getLabelForIR(irWin.selectedIR)} - - {irWin.pipeline.map((p, i) => { - const preview = - p.flags.length <= 25 - ? p.flags - : `${p.flags.slice(0, 15)}...${p.flags.slice(-10)}`; - return ( - - - → - + + {getLabelForIR(irWin.selectedIR)} + + {irWin.pipeline.map((p, i) => { + const preview = + p.flags.length <= 25 + ? p.flags + : `${p.flags.slice(0, 15)}...${p.flags.slice(-10)}`; + return ( + + + → + + + handleEditPass(irWin.id, i, p.tool, p.flags) + } + style={{ + backgroundColor: "#a5d6a7", + padding: "2px 8px", + borderRadius: "4px", + cursor: "pointer", + fontWeight: "bold", + display: "flex", + flexDirection: "column", + lineHeight: "1.2em", + }} + > + {p.tool} - handleEditPass(irWin.id, i, p.tool, p.flags) - } style={{ - backgroundColor: "#a5d6a7", - padding: "2px 8px", - borderRadius: "4px", - cursor: "pointer", - fontWeight: "bold", - display: "flex", - flexDirection: "column", - lineHeight: "1.2em", + fontSize: "0.8em", + color: "#555", }} > - {p.tool} - - {preview} - + {preview} - - ); - })} -
+ + + ); + })}
- )} +
+ )}
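
A minimal end-to-end sketch of the new Triton pass plumbing, assuming the backend from this diff is running on localhost:8000 (the same host/port pattern the frontend fetch uses). The pass flags below are placeholders rather than verified triton-opt options, and `kernel_src` stands in for real Triton kernel source; only `code`, `ir_type`, `triton_opt`, `dump_after_each_opt`, the `/generate_ir` endpoint, and the `output` field come from the diff itself:

    # Each "&&"-separated segment of triton_opt becomes one ("triton-opt", flags)
    # stage in build_pipeline; compile_triton_ir then feeds its cached IR dump
    # through apply_optional_passes, so dump_after_each_opt shows the IR after
    # every stage.
    import requests  # any HTTP client works; requests is an assumption

    kernel_src = "..."  # placeholder: real Triton kernel source goes here

    body = {
        "code": kernel_src,
        "ir_type": "triton_gpu_ir",
        "triton_opt": "--pass-a && --pass-b",  # placeholder flags, not real triton-opt passes
        "dump_after_each_opt": True,
    }

    resp = requests.post("http://localhost:8000/generate_ir", json=body)
    print(resp.json()["output"])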