Skip to content

Commit a2683b1

Browse files
committed
review changes 08312025
1 parent 18bc550 commit a2683b1

File tree

8 files changed

+49
-47
lines changed

8 files changed

+49
-47
lines changed

CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,10 @@ if(TRITON_SHARED_BUILD_CPU_BACKEND)
5353

5454
Python3::Module
5555
pybind11::headers
56-
${Python3_LIBRARIES}
5756
)
5857
endif()
5958

60-
# Add symlinks to selected pytest files and the clang-format setting in triton. The tests are imported into triton-shared’s test folder to
59+
# Add symlinks to selected pytest files and the clang-format setting in triton. The tests are imported into triton-shared’s test folder to
6160
# run under triton-shared's conftest configuration, and the clang-format link ensures consistent code style enforcement across both repositories.
6261
cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR "python" "examples" "test_core.py" OUTPUT_VARIABLE TRITON_SHARED_TEST_CORE)
6362
cmake_path(APPEND CMAKE_CURRENT_SOURCE_DIR "python" "examples" "test_annotations.py" OUTPUT_VARIABLE TRITON_SHARED_TEST_ANNOTATIONS)

backend/compiler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ def _ttsharedir_to_llir(ttsharedir: str):
8686
pm = ir.pass_manager(context)
8787
pm.enable_debug()
8888
triton_shared.to_llir.add_convert_linalg_to_affine_loops(pm)
89+
# Note: eliminate-empty-tensors fails when there are multiple func.return ops
90+
# in a single kernel which are the results of early returns.
91+
# See python/examples/test_early_return.py for examples.
92+
# We disable this pass for now since performance on CPU isn't the main
93+
# focus at the moment.
94+
# triton_shared.to_llir.add_eliminate_empty_tensors(pm)
8995
triton_shared.to_llir.add_empty_tensor_to_alloc_tensor(pm)
9096
triton_shared.to_llir.add_one_shot_bufferize_with_options(
9197
pm, allow_return_allocs_from_loops=True)
@@ -103,8 +109,12 @@ def _ttsharedir_to_llir(ttsharedir: str):
103109
triton_shared.to_llir.add_finalize_memref_to_llvm(pm)
104110
triton_shared.to_llir.add_convert_func_to_llvm(pm)
105111
triton_shared.to_llir.add_convert_cf_to_llvm(pm)
112+
# Lowering memrefs creates more affine.apply ops.
113+
# Lowering these affine ops again creates further arith ops,
114+
# so we have to run these two passes again here.
106115
triton_shared.to_llir.add_lower_affine(pm)
107116
triton_shared.to_llir.add_convert_arith_to_llvm(pm)
117+
# Remove all unrealized casts created
108118
triton_shared.to_llir.add_reconcile_unrealized_casts(pm)
109119
pm.run(mod)
110120

include/triton-shared/Conversion/TPtrToLLVM/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ include "mlir/Pass/PassBase.td"
55

66
def TPtrToLLVM : Pass<"tptr-to-llvm", "mlir::ModuleOp"> {
77
let summary = "Convert Tptr operations into LLVM";
8-
let dependentDialects = ["mlir::tptr::TPtrDialect", "mlir::LLVM::LLVMDialect"];
8+
let dependentDialects = ["mlir::tptr::TPtrDialect", "mlir::LLVM::LLVMDialect", "mlir::ptr::PtrDialect"];
99
}
1010
#endif

include/triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace tptr {
1111

1212
#define GEN_PASS_DECL
1313
#include "triton-shared/Conversion/TPtrToLLVM/Passes.h.inc"
14+
1415
void populateTPtrToLLVMConversionPatterns(RewritePatternSet &patterns,
1516
TypeConverter &typeconverter);
1617

lib/Conversion/TPtrToLLVM/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ add_triton_library(TPtrToLLVM
1111
MLIRPass
1212
MLIRTransforms
1313
MLIRSupport
14-
MLIRReconcileUnrealizedCasts
15-
TPtrIR
1614
MLIRDialectUtils
15+
TPtrIR
1716
)

lib/Dialect/TPtr/IR/TPtrOps.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1+
#include "mlir/Interfaces/SideEffectInterfaces.h" // Required for IR/TPtrOps.h.inc
12
#include "mlir/Bytecode/BytecodeOpInterface.h"
2-
#include "mlir/Interfaces/SideEffectInterfaces.h"
33

4+
#include "mlir/IR/OpImplementation.h"
45
#include "mlir/IR/Builders.h"
56
#include "mlir/IR/BuiltinAttributes.h"
67
#include "mlir/IR/BuiltinTypes.h"
7-
#include "mlir/IR/Dialect.h"
88
#include "mlir/IR/MLIRContext.h"
9-
#include "mlir/IR/OpDefinition.h"
10-
#include "mlir/IR/OpImplementation.h"
119
#include "mlir/IR/OperationSupport.h"
10+
#include "mlir/IR/OpDefinition.h"
11+
#include "mlir/IR/Dialect.h"
1212

1313
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
1414
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
15-
#include "mlir/Interfaces/DataLayoutInterfaces.h"
1615

1716
#define GET_OP_CLASSES
1817
#include "triton-shared/Dialect/TPtr/IR/TPtrOps.h.inc"

tools/triton-shared-opt/RegisterTritonSharedDialects.h

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
11
#pragma once
2+
// Core dialects and passes needed by triton-shared-opt
3+
#include "mlir/Dialect/Arith/IR/Arith.h"
24
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
5+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
36
#include "mlir/Dialect/Func/IR/FuncOps.h"
47
#include "mlir/Dialect/Linalg/IR/Linalg.h"
58
#include "mlir/Dialect/Linalg/Passes.h"
9+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
10+
#include "mlir/Dialect/Math/IR/Math.h"
611
#include "mlir/Dialect/MemRef/IR/MemRef.h"
7-
// #include "mlir/Dialect/Ptr/IR/PtrDialect.h"
12+
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
13+
#include "mlir/Dialect/SCF/IR/SCF.h"
814
#include "mlir/Dialect/Tensor/IR/Tensor.h"
915

1016
#include "triton/Dialect/Triton/IR/Dialect.h"
1117
#include "triton/Dialect/Triton/Transforms/Passes.h"
1218

1319
#include "triton-shared/Conversion/StructuredToMemref/Passes.h"
20+
#include "triton-shared/Conversion/TPtrToLLVM/Passes.h"
1421
#include "triton-shared/Conversion/TritonArithToLinalg/Passes.h"
1522
#include "triton-shared/Conversion/TritonPtrToMemref/Passes.h"
1623
#include "triton-shared/Conversion/TritonToLinalg/Passes.h"
@@ -22,7 +29,6 @@
2229
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
2330
#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"
2431
#include "triton-shared/Transform/AddLLVMDebugInfo/Passes.h"
25-
#include "triton-shared/Conversion/TPtrToLLVM/Passes.h"
2632

2733
#include "mlir/InitAllPasses.h"
2834

@@ -43,18 +49,11 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry &registry) {
4349

4450
// TODO: register Triton & TritonGPU passes
4551
registry.insert<
46-
mlir::tptr::TPtrDialect,
47-
mlir::ptr::PtrDialect,
48-
mlir::ttx::TritonTilingExtDialect,
49-
mlir::tts::TritonStructuredDialect,
50-
mlir::triton::TritonDialect,
51-
mlir::cf::ControlFlowDialect, mlir::scf::SCFDialect,
52-
mlir::math::MathDialect, mlir::arith::ArithDialect,
53-
// mlir::gpu::GPUDialect,
54-
mlir::linalg::LinalgDialect,
55-
mlir::func::FuncDialect,
56-
mlir::tensor::TensorDialect,
57-
mlir::memref::MemRefDialect,
58-
mlir::bufferization::BufferizationDialect,
59-
mlir::LLVM::LLVMDialect>();
52+
mlir::tptr::TPtrDialect, mlir::ptr::PtrDialect,
53+
mlir::ttx::TritonTilingExtDialect, mlir::tts::TritonStructuredDialect,
54+
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
55+
mlir::scf::SCFDialect, mlir::math::MathDialect, mlir::arith::ArithDialect,
56+
mlir::linalg::LinalgDialect, mlir::func::FuncDialect,
57+
mlir::tensor::TensorDialect, mlir::memref::MemRefDialect,
58+
mlir::bufferization::BufferizationDialect, mlir::LLVM::LLVMDialect>();
6059
}

triton_shared.cc

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ namespace py = pybind11;
5656
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
5757

5858
// LLVM: Debug
59-
#include "llvm/Support/Debug.h" // Key header file
59+
#include "llvm/Support/Debug.h" // Key header file
6060

6161
// MLIR: Top-level Transforms
6262
#include "mlir/Transforms/Passes.h"
@@ -69,17 +69,17 @@ namespace py = pybind11;
6969
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
7070
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
7171

72-
#define ADD_PASS_WRAPPER_0(name, builder) \
72+
#define ADD_PASS_WRAPPER_0(name, builder) \
7373
m.def(name, [](mlir::PassManager &pm) { pm.addPass(builder()); })
7474

75-
#define ADD_PASS_WRAPPER_1(name, builder, ty0) \
76-
m.def(name, \
75+
#define ADD_PASS_WRAPPER_1(name, builder, ty0) \
76+
m.def(name, \
7777
[](mlir::PassManager &pm, ty0 val0) { pm.addPass(builder(val0)); })
7878

79-
#define ADD_PASS_WRAPPER_1_ARG(name, builder, ty0, arg0, val0) \
80-
m.def( \
81-
name, \
82-
[](mlir::PassManager &pm, ty0 arg0) { pm.addPass(builder(val0)); }, \
79+
#define ADD_PASS_WRAPPER_1_ARG(name, builder, ty0, arg0, val0) \
80+
m.def( \
81+
name, \
82+
[](mlir::PassManager &pm, ty0 arg0) { pm.addPass(builder(val0)); }, \
8383
py::arg("pm"), py::arg(#arg0) = val0);
8484

8585
// Function to set MLIR/LLVM debug type
@@ -91,7 +91,8 @@ void enable_mlir_debug(const std::string &debug_type) {
9191
void init_to_llvm(py::module &&m) {
9292
using namespace mlir;
9393

94-
// Note: Linalg conversions may not be available in this MLIR version
94+
ADD_PASS_WRAPPER_0("add_eliminate_empty_tensors",
95+
bufferization::createEmptyTensorEliminationPass);
9596
ADD_PASS_WRAPPER_0("add_convert_linalg_to_affine_loops",
9697
createConvertLinalgToAffineLoopsPass);
9798
ADD_PASS_WRAPPER_0("add_empty_tensor_to_alloc_tensor",
@@ -143,18 +144,12 @@ void init_triton_shared_ir(py::module &&m) {
143144
// ::mlir::triton::gpu::TritonGPUDialect,
144145
// ::mlir::triton::instrument::TritonInstrumentDialect,
145146
::mlir::linalg::LinalgDialect,
146-
::mlir::bufferization::BufferizationDialect,
147-
::mlir::tptr::TPtrDialect,
148-
::mlir::math::MathDialect,
149-
::mlir::memref::MemRefDialect,
150-
::mlir::arith::ArithDialect,
151-
::mlir::scf::SCFDialect,
152-
::mlir::vector::VectorDialect,
153-
::mlir::cf::ControlFlowDialect,
154-
::mlir::triton::proton::ProtonDialect,
155-
::mlir::LLVM::LLVMDialect,
156-
::mlir::ub::UBDialect,
157-
::mlir::func::FuncDialect>();
147+
::mlir::bufferization::BufferizationDialect, ::mlir::tptr::TPtrDialect,
148+
::mlir::math::MathDialect, ::mlir::memref::MemRefDialect,
149+
::mlir::arith::ArithDialect, ::mlir::scf::SCFDialect,
150+
::mlir::vector::VectorDialect, ::mlir::cf::ControlFlowDialect,
151+
::mlir::triton::proton::ProtonDialect, ::mlir::LLVM::LLVMDialect,
152+
::mlir::ub::UBDialect, ::mlir::func::FuncDialect>();
158153

159154
// Register interfaces and translations
160155
// ::mlir::registerAllDialects(registry);

0 commit comments

Comments
 (0)