diff --git a/CMakeLists.txt b/CMakeLists.txt index e2008ebfa..1dd46e0fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -135,6 +135,10 @@ function(add_triton_object name) ) if (FLAGTREE_BACKEND STREQUAL "ascend") + set(ASCENDNPU_IR_SRC_DIR ${PROJECT_SOURCE_DIR}/third_party/ascendnpu-ir) + set(ASCENDNPU_IR_BINARY_DIR ${PROJECT_BINARY_DIR}/third_party/ascendnpu-ir) + include_directories(${ASCENDNPU_IR_SRC_DIR}/bishengir/include) + include_directories(${ASCENDNPU_IR_BINARY_DIR}/bishengir/include) set(patched_depends "") foreach(dep ${ARG_DEPENDS}) list(FIND PATCHED_TRITON_DEPENDS "${dep}" index) diff --git a/third_party/ascend/backend/compiler.py b/third_party/ascend/backend/compiler.py index 865bee249..029fc268b 100644 --- a/third_party/ascend/backend/compiler.py +++ b/third_party/ascend/backend/compiler.py @@ -70,6 +70,9 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=False): pm.enable_debug() # Add pass here. ascend.passes.convert.add_triton_to_linalg_pipeline(pm) + ascend.passes.convert.add_triton_to_llvm(pm) + ascend.passes.convert.add_triton_to_hfusion(pm) + ascend.passes.convert.add_triton_to_hivm(pm) pm.run(mod) return str(mod) ''' diff --git a/third_party/ascend/triton_ascend.cpp b/third_party/ascend/triton_ascend.cpp index d323f95af..7260ade61 100644 --- a/third_party/ascend/triton_ascend.cpp +++ b/third_party/ascend/triton_ascend.cpp @@ -4,6 +4,9 @@ #include "mlir/Pass/PassManager.h" #include "passes.h" #include "triton-shared/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimental.h" +#include "triton-shared/TritonToHFusion/TritonToHFusion.h" +#include "triton-shared/TritonToHIVM/TritonToHIVM.h" +#include "triton-shared/TritonToLLVM/TritonToLLVM.h" #define PY_SSIZE_T_CLEAN #include @@ -12,10 +15,16 @@ namespace py = pybind11; void init_triton_ascend_passes_convert(py::module &&m) { ADD_PASS_WRAPPER_0("add_triton_to_linalg_pipeline", mlir::triton::createTritonToLinalgExperimentalPass); + ADD_PASS_WRAPPER_0("add_triton_to_llvm", + mlir::triton::createTritonToLLVMPass); + ADD_PASS_WRAPPER_0("add_triton_to_hfusion", + mlir::triton::createTritonToHFusionPass); + ADD_PASS_WRAPPER_0("add_triton_to_hivm", + mlir::triton::createTritonToHIVMPass); } // register ascend passes to triton void init_triton_ascend(py::module &&m) { auto passes = m.def_submodule("passes"); init_triton_ascend_passes_convert(passes.def_submodule("convert")); -} \ No newline at end of file +}