Skip to content

Commit 9147628

Browse files
authored
Add triton-opt and triton-llvm-opt for Triton (#8)
Also allow to use opt, llc and custom tools for Triron IRs. mlir-opt remains disabled for Triton as it seem to be unusable without --allow-unregistered-dialect. Yet a user can call mlir-opt as unspecified custom tool (if anybody reading this commit message disagrees - please don't hesitate and comment on the PR or open an issue). TODO: Triton support is still very raw, need to dump IR not from cache dir. With that fix - merge triton_llvm_ir and llvm_ir IR targets. Explore AMD and Intel lowering for Triton. opt tool bar starts to get overloaded - explicit listing of tools in MLIR ecosystem supported by the app is good for novices like myself, but for advanced users it should be a command line. git blame note: I've applied npx prettier --write without preserving intermediate stage, so diff became bigger, then I have anticipated.
1 parent 529bbed commit 9147628

9 files changed

+422
-287
lines changed

backend/server.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
TORCH_MLIR_OPT_PATH = os.environ.get("TORCH_MLIR_OPT_PATH", "")
4545
LLVM_BIN_PATH = os.environ.get("LLVM_BIN_PATH", "")
46+
TRITON_OPT_PATH = os.environ.get("TRITON_OPT_PATH", "")
4647

4748

4849
class CodeRequest(BaseModel):
@@ -55,6 +56,8 @@ class CodeRequest(BaseModel):
5556
mlir_translate: Optional[str] = ""
5657
llvm_opt: Optional[str] = ""
5758
llc: Optional[str] = ""
59+
triton_opt: Optional[str] = ""
60+
triton_llvm_opt: Optional[str] = ""
5861
user_tool: Optional[str] = ""
5962
dump_after_each_opt: Optional[bool] = False
6063

@@ -156,6 +159,10 @@ def apply_optional_passes(
156159
tool_path = LLVM_BIN_PATH + "opt"
157160
elif tool == "llc":
158161
tool_path = LLVM_BIN_PATH + "llc"
162+
elif tool == "triton-opt":
163+
tool_path = TRITON_OPT_PATH + "triton-opt"
164+
elif tool == "triton-llvm-opt":
165+
tool_path = TRITON_OPT_PATH + "triton-llvm-opt"
159166
elif tool == "user-tool":
160167
tokens = flags.strip().split()
161168
if not tokens:
@@ -341,7 +348,9 @@ def generate_llvm_ir(
341348

342349

343350
# TODO: Figure out static compilation.
344-
def compile_triton_ir(code: str, ir_type: str) -> str:
351+
def compile_triton_ir(
352+
code: str, ir_type: str, pipeline: List[Tuple[str, str]], dump_each: bool
353+
) -> str:
345354
try:
346355
code_hash = hash_code(code)
347356
cache_info = cached_triton_runs.get(code_hash)
@@ -403,7 +412,8 @@ def compile_triton_ir(code: str, ir_type: str) -> str:
403412

404413
cached_triton_runs[code_hash]["active_users"] += 1
405414

406-
return "\n\n".join(output_parts)
415+
ir_dump = "\n\n".join(output_parts)
416+
return apply_optional_passes(ir_dump, pipeline, dump_each)
407417

408418
except Exception as e:
409419
return f"Error compiling Triton IR: {str(e)}"
@@ -427,6 +437,12 @@ def build_pipeline(request: CodeRequest) -> List[Tuple[str, str]]:
427437
if request.llc:
428438
for stage in request.llc.split("&&"):
429439
pipeline.append(("llc", stage.strip()))
440+
if request.llc:
441+
for stage in request.triton_opt.split("&&"):
442+
pipeline.append(("triton_opt", stage.strip()))
443+
if request.llc:
444+
for stage in request.triton_llvm_opt.split("&&"):
445+
pipeline.append(("triton_llvm_opt", stage.strip()))
430446
if request.user_tool:
431447
for stage in request.user_tool.split("&&"):
432448
pipeline.append(("user-tool", stage.strip()))
@@ -437,7 +453,10 @@ def build_pipeline(request: CodeRequest) -> List[Tuple[str, str]]:
437453
def process_model(request: CodeRequest) -> str:
438454
try:
439455
if request.ir_type.startswith("triton"):
440-
return compile_triton_ir(request.code, request.ir_type)
456+
pipeline = build_pipeline(request)
457+
return compile_triton_ir(
458+
request.code, request.ir_type, pipeline, request.dump_after_each_opt
459+
)
441460

442461
if request.ir_type == "raw_ir" and request.selected_language == "pytorch":
443462
# Execute user Python, capture stdout.

0 commit comments

Comments
 (0)