diff --git a/backend/server.py b/backend/server.py
index a84e64f..9743156 100644
--- a/backend/server.py
+++ b/backend/server.py
@@ -43,6 +43,7 @@
TORCH_MLIR_OPT_PATH = os.environ.get("TORCH_MLIR_OPT_PATH", "")
LLVM_BIN_PATH = os.environ.get("LLVM_BIN_PATH", "")
+TRITON_OPT_PATH = os.environ.get("TRITON_OPT_PATH", "")
class CodeRequest(BaseModel):
@@ -55,6 +56,8 @@ class CodeRequest(BaseModel):
mlir_translate: Optional[str] = ""
llvm_opt: Optional[str] = ""
llc: Optional[str] = ""
+ triton_opt: Optional[str] = ""
+ triton_llvm_opt: Optional[str] = ""
user_tool: Optional[str] = ""
dump_after_each_opt: Optional[bool] = False
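Like `llvm_opt` and `llc` just above them, each new field carries an `&&`-separated list of flag strings; `build_pipeline` later splits every segment into its own tool invocation. The convention in miniature (flag values are illustrative):

```python
# The "&&" convention shared by all of the optional pass fields:
# every segment becomes one standalone tool invocation, in order.
triton_opt = "--canonicalize && --cse"
stages = [s.strip() for s in triton_opt.split("&&")]
assert stages == ["--canonicalize", "--cse"]
```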
@@ -156,6 +159,10 @@ def apply_optional_passes(
tool_path = LLVM_BIN_PATH + "opt"
elif tool == "llc":
tool_path = LLVM_BIN_PATH + "llc"
+ elif tool == "triton-opt":
+ tool_path = TRITON_OPT_PATH + "triton-opt"
+ elif tool == "triton-llvm-opt":
+ tool_path = TRITON_OPT_PATH + "triton-llvm-opt"
elif tool == "user-tool":
tokens = flags.strip().split()
if not tokens:
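Two notes on this dispatch. First, `TRITON_OPT_PATH` is concatenated directly with the binary name, exactly like `LLVM_BIN_PATH`, so it must either be empty (resolving via `$PATH`) or end with a path separator. Second, since the body of `apply_optional_passes` is not part of this diff, the following is only a hedged sketch of how a resolved stage is presumably executed; `run_stage` is an invented name:

```python
# Hedged sketch of how a resolved stage is presumably executed downstream
# (the actual body of apply_optional_passes is not shown in this diff);
# run_stage is an invented helper name.
import subprocess

def run_stage(tool_path: str, flags: str, ir_text: str) -> str:
    result = subprocess.run(
        [tool_path, *flags.split()],   # e.g. ["triton-opt", "--cse"]
        input=ir_text,                 # IR is piped in via stdin
        capture_output=True,
        text=True,
    )
    # Surface stderr on failure so broken passes stay visible in the UI.
    return result.stdout if result.returncode == 0 else result.stderr
```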
@@ -341,7 +348,9 @@ def generate_llvm_ir(
# TODO: Figure out static compilation.
-def compile_triton_ir(code: str, ir_type: str) -> str:
+def compile_triton_ir(
+ code: str, ir_type: str, pipeline: List[Tuple[str, str]], dump_each: bool
+) -> str:
try:
code_hash = hash_code(code)
cache_info = cached_triton_runs.get(code_hash)
@@ -403,7 +412,8 @@ def compile_triton_ir(code: str, ir_type: str) -> str:
cached_triton_runs[code_hash]["active_users"] += 1
- return "\n\n".join(output_parts)
+ ir_dump = "\n\n".join(output_parts)
+ return apply_optional_passes(ir_dump, pipeline, dump_each)
except Exception as e:
return f"Error compiling Triton IR: {str(e)}"
@@ -427,6 +437,12 @@ def build_pipeline(request: CodeRequest) -> List[Tuple[str, str]]:
if request.llc:
for stage in request.llc.split("&&"):
pipeline.append(("llc", stage.strip()))
+ if request.triton_opt:
+ for stage in request.triton_opt.split("&&"):
+ pipeline.append(("triton-opt", stage.strip()))
+ if request.triton_llvm_opt:
+ for stage in request.triton_llvm_opt.split("&&"):
+ pipeline.append(("triton-llvm-opt", stage.strip()))
if request.user_tool:
for stage in request.user_tool.split("&&"):
pipeline.append(("user-tool", stage.strip()))
@@ -437,7 +453,10 @@ def build_pipeline(request: CodeRequest) -> List[Tuple[str, str]]:
def process_model(request: CodeRequest) -> str:
try:
if request.ir_type.startswith("triton"):
- return compile_triton_ir(request.code, request.ir_type)
+ pipeline = build_pipeline(request)
+ return compile_triton_ir(
+ request.code, request.ir_type, pipeline, request.dump_after_each_opt
+ )
if request.ir_type == "raw_ir" and request.selected_language == "pytorch":
# Execute user Python, capture stdout.
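End to end, a Triton IR window now flows through `build_pipeline` before `compile_triton_ir`. A hedged way to exercise the path without the frontend: the URL, port, and `output` response key all appear in the frontend code below, while the `selected_language` value of `"triton"` is an assumption and `requests` is a third-party dependency used here for brevity:

```python
# Drive the backend directly; /generate_ir, port 8000, and the "output"
# response key are taken from the frontend fetch below. "triton" as the
# selected_language value is an assumption.
import requests

resp = requests.post(
    "http://localhost:8000/generate_ir",
    json={
        "code": open("kernel.py").read(),
        "ir_type": "triton_ir",
        "selected_language": "triton",
        "triton_opt": "--canonicalize",
        "dump_after_each_opt": False,
    },
)
print(resp.json()["output"])
```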
diff --git a/src/app/page.js b/src/app/page.js
index 6b43532..8f91e25 100644
--- a/src/app/page.js
+++ b/src/app/page.js
@@ -84,6 +84,9 @@ const pytorchIROptions = [
const tritonIROptions = [
{ value: "triton_ir", label: "Triton IR" },
{ value: "triton_gpu_ir", label: "Triton GPU IR" },
+ // FIXME: 'Triton' LLVM IR is essentially plain LLVM IR. It is kept separate
+ // for now because all Triton IRs are just dumps from the cache dir after
+ // model execution.
{ value: "triton_llvm_ir", label: "LLVM IR" },
{ value: "triton_nvptx", label: "NVPTX" },
];
@@ -168,8 +171,10 @@ export default function PyTorchTritonExplorer() {
if (flags !== null) {
setIrWindows((prev) =>
prev.map((w) =>
- w.id === id ? { ...w, pipeline: [...w.pipeline, { tool, flags }] } : w
- )
+ w.id === id
+ ? { ...w, pipeline: [...w.pipeline, { tool, flags }] }
+ : w,
+ ),
);
}
};
@@ -185,8 +190,8 @@ export default function PyTorchTritonExplorer() {
{ tool: currentTool, flags: currentFlags },
],
}
- : w
- )
+ : w,
+ ),
);
setModalVisible(false);
};
@@ -206,11 +211,11 @@ export default function PyTorchTritonExplorer() {
? {
...w,
pipeline: w.pipeline.map((p, i) =>
- i === editPassIndex ? { tool: editTool, flags: editFlags } : p
+ i === editPassIndex ? { tool: editTool, flags: editFlags } : p,
),
}
- : w
- )
+ : w,
+ ),
);
setEditModalVisible(false);
};
@@ -223,8 +228,8 @@ export default function PyTorchTritonExplorer() {
...w,
pipeline: w.pipeline.filter((_, i) => i !== editPassIndex),
}
- : w
- )
+ : w,
+ ),
);
setEditModalVisible(false);
};
@@ -285,7 +290,7 @@ export default function PyTorchTritonExplorer() {
const toggleCollapse = (id) => {
setIrWindows((prev) =>
- prev.map((w) => (w.id === id ? { ...w, collapsed: !w.collapsed } : w))
+ prev.map((w) => (w.id === id ? { ...w, collapsed: !w.collapsed } : w)),
);
};
@@ -293,14 +298,14 @@ export default function PyTorchTritonExplorer() {
const value = event.target.value;
setIrWindows((prev) =>
prev.map((w) =>
- w.id === id ? { ...w, selectedIR: value, pipeline: [] } : w
- )
+ w.id === id ? { ...w, selectedIR: value, pipeline: [] } : w,
+ ),
);
};
const generateIR = async (id) => {
setIrWindows((prev) =>
- prev.map((w) => (w.id === id ? { ...w, loading: true } : w))
+ prev.map((w) => (w.id === id ? { ...w, loading: true } : w)),
);
const irWin = irWindows.find((w) => w.id === id);
@@ -338,17 +343,20 @@ export default function PyTorchTritonExplorer() {
dump_after_each_opt: irWin.dumpAfterEachOpt,
};
- const response = await fetch("http://" + window.location.hostname + ":8000/generate_ir", {
- method: "POST",
- headers: { "Content-Type": "application/json" },
- body: JSON.stringify(body),
- });
+ const response = await fetch(
+ "http://" + window.location.hostname + ":8000/generate_ir",
+ {
+ method: "POST",
+ headers: { "Content-Type": "application/json" },
+ body: JSON.stringify(body),
+ },
+ );
const data = await response.json();
setIrWindows((prev) =>
prev.map((w) =>
- w.id === id ? { ...w, output: data.output, loading: false } : w
- )
+ w.id === id ? { ...w, output: data.output, loading: false } : w,
+ ),
);
};
@@ -575,190 +583,69 @@ export default function PyTorchTritonExplorer() {
))}
- {(selectedLanguage === "pytorch" ||
- selectedLanguage == "raw_ir") &&
- (() => {
- const allowTorchMlirOpt = ["torch_script_graph_ir", "torch_mlir", "raw_ir"].includes(
- irWin.selectedIR
- );
- const allowMlirOpt = [
- "torch_script_graph_ir",
- "torch_mlir",
- "tosa_mlir",
- "linalg_on_tensors_mlir",
- "stablehlo_mlir",
- "llvm_mlir",
- "raw_ir",
- ].includes(irWin.selectedIR);
- const allowMlirTranslate = [
- "torch_script_graph_ir",
- "torch_mlir",
- "tosa_mlir",
- "linalg_on_tensors_mlir",
- "stablehlo_mlir",
- "llvm_mlir",
- "raw_ir",
- ].includes(irWin.selectedIR);
- const allowLlvmOpt = true;
- const allowLLC = true;
- const allowUserTool = true;
-
- if (
- !allowTorchMlirOpt &&
- !allowMlirOpt &&
- !allowMlirTranslate &&
- !allowLlvmOpt &&
- !allowUserTool
- )
- return null;
-
- return (
-
- {allowTorchMlirOpt && (
-
- )}
- {allowMlirOpt && (
-
- )}
- {allowMlirTranslate && (
-
- )}
- {allowLlvmOpt && (
-
- )}
- {allowLLC && (
-
- )}
- {allowUserTool && (
-
- )}
+ {(() => {
+ const allowTorchMlirOpt = [
+ "torch_script_graph_ir",
+ "torch_mlir",
+ "raw_ir",
+ ].includes(irWin.selectedIR);
+ const allowMlirOpt = [
+ "torch_script_graph_ir",
+ "torch_mlir",
+ "tosa_mlir",
+ "linalg_on_tensors_mlir",
+ "stablehlo_mlir",
+ "llvm_mlir",
+ "raw_ir",
+ ].includes(irWin.selectedIR);
+ const allowMlirTranslate = [
+ "torch_script_graph_ir",
+ "torch_mlir",
+ "tosa_mlir",
+ "linalg_on_tensors_mlir",
+ "stablehlo_mlir",
+ "llvm_mlir",
+ "raw_ir",
+ ].includes(irWin.selectedIR);
+ const allowTritonOpt = [
+ "triton_ir",
+ "triton_gpu_ir",
+ ].includes(irWin.selectedIR);
+ const allowTritonLLVMOpt = [
+ "triton_ir",
+ "triton_gpu_ir",
+ "triton_llvm_ir",
+ ].includes(irWin.selectedIR);
+ const allowLlvmOpt = irWin.selectedIR !== "triton_nvptx";
+ const allowLLC = irWin.selectedIR !== "triton_nvptx";
+ const allowUserTool = true;
+
+ if (
+ !allowTorchMlirOpt &&
+ !allowMlirOpt &&
+ !allowMlirTranslate &&
+ !allowTritonOpt &&
+ !allowTritonLLVMOpt &&
+ !allowLlvmOpt &&
+ !allowLLC &&
+ !allowUserTool
+ )
+ return null;
+
+ return (
+
+ {allowTorchMlirOpt && (
-
- );
- })()}
-
- {(selectedLanguage === "pytorch" ||
- selectedLanguage == "raw_ir") &&
- irWin.pipeline.length > 0 && (
+ )}
+ {allowMlirOpt && (
+
+ )}
+ {allowMlirTranslate && (
+
+ )}
+ {allowTritonOpt && (
+
+ )}
+ {allowTritonLLVMOpt && (
+
+ )}
+ {allowLlvmOpt && (
+
+ )}
+ {allowLLC && (
+
+ )}
+ {allowUserTool && (
+
+ )}
+
+
+ );
+ })()}
+
+ {irWin.pipeline.length > 0 && (
+
-
-
- Compilation pipeline:
-
+ Compilation pipeline:
+
-
- {getLabelForIR(irWin.selectedIR)}
-
- {irWin.pipeline.map((p, i) => {
- const preview =
- p.flags.length <= 25
- ? p.flags
- : `${p.flags.slice(0, 15)}...${p.flags.slice(-10)}`;
- return (
-
-
- →
-
+
+ {getLabelForIR(irWin.selectedIR)}
+
+ {irWin.pipeline.map((p, i) => {
+ const preview =
+ p.flags.length <= 25
+ ? p.flags
+ : `${p.flags.slice(0, 15)}...${p.flags.slice(-10)}`;
+ return (
+
+
+ →
+
+
+ handleEditPass(irWin.id, i, p.tool, p.flags)
+ }
+ style={{
+ backgroundColor: "#a5d6a7",
+ padding: "2px 8px",
+ borderRadius: "4px",
+ cursor: "pointer",
+ fontWeight: "bold",
+ display: "flex",
+ flexDirection: "column",
+ lineHeight: "1.2em",
+ }}
+ >
+ {p.tool}
- handleEditPass(irWin.id, i, p.tool, p.flags)
- }
style={{
- backgroundColor: "#a5d6a7",
- padding: "2px 8px",
- borderRadius: "4px",
- cursor: "pointer",
- fontWeight: "bold",
- display: "flex",
- flexDirection: "column",
- lineHeight: "1.2em",
+ fontSize: "0.8em",
+ color: "#555",
}}
>
- {p.tool}
-
- {preview}
-
+ {preview}
-
- );
- })}
-
+
+
+ );
+ })}
- )}
+
+ )}
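One detail worth calling out in the retained badge markup: long flag strings are abbreviated to a head/tail preview, with strings of 25 characters or fewer shown whole and longer ones keeping the first 15 and last 10 characters around an ellipsis. The same rule as a standalone sketch, written in Python for consistency with the backend examples above:

```python
# The flag-preview rule from the JSX above, as a standalone helper.
def preview(flags: str) -> str:
    return flags if len(flags) <= 25 else f"{flags[:15]}...{flags[-10:]}"

assert preview("--cse") == "--cse"
assert preview("abcdefghijklmnopqrstuvwxyz") == "abcdefghijklmno...qrstuvwxyz"
```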