Skip to content

Commit 91ac8d8

Browse files
authored
Introduce triton-ptr-to-memref pass (#211)
This PR introduces the `triton-ptr-to-memref` pass responsible for converting function signature that uses triton ptr to use memref instead. This is part of the work to allow triton-shared to lower gather / scatter pointer sequences. Much of this code is copied from the current `StructuredToMemref` pass which will be cleaned up in a later PR. --- # Intended lowering pipeline - triton-to-structured (no changes): - analyzes structured addptr sequences - introduces `tts.make_tptr %ptr_arg with offsets and strides` - introduces `tts.load` and `tts.store` - leaves unstructured addptr sequences and their corresponding `tt.load` and `tt.store` intact - triton-to-unstructured (#210): - introduces `tts.gather` and `tts.scatter` - removes all pointer-producing ops such as `tt.addptr` and `tt.splat` and replaces them with offset-producing ops - structured-to-memref (#217): - currently converts everything to memref including scalar addptr and kernel arguments - will change to just convert ops in the `tts` dialect to `memref` with the exception of `tts.gather` and `tts.scatter` - unstructured-to-memref (#216): - converts the remaining unstructured `tts.gather`, `tts.scatter` into memref - triton-ptr-to-memref (#211): - converts kernel arguments with pointer type to memref
1 parent 986cea8 commit 91ac8d8

File tree

15 files changed

+406
-0
lines changed

15 files changed

+406
-0
lines changed

include/triton-shared/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ add_subdirectory(TritonToStructured)
44
add_subdirectory(TritonArithToLinalg)
55
add_subdirectory(TritonToUnstructured)
66
add_subdirectory(StructuredToMemref)
7+
add_subdirectory(TritonPtrToMemref)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonPtrToMemref)
3+
add_public_tablegen_target(TritonPtrToMemrefConversionPassIncGen)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TRITON_PTR_TO_MEMREF_CONVERSION_PASSES_H
2+
#define TRITON_PTR_TO_MEMREF_CONVERSION_PASSES_H
3+
4+
#include "triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h"
5+
6+
namespace mlir {
7+
namespace triton {
8+
9+
#define GEN_PASS_REGISTRATION
10+
#include "triton-shared/Conversion/TritonPtrToMemref/Passes.h.inc"
11+
12+
} // namespace triton
13+
} // namespace mlir
14+
15+
#endif
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef TRITON_PTR_TO_MEMREF_CONVERSION_PASSES
2+
#define TRITON_PTR_TO_MEMREF_CONVERSION_PASSES
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
def TritonPtrToMemref : Pass<"triton-ptr-to-memref", "mlir::ModuleOp"> {
7+
let summary = "Convert triton pointer to unranked memref";
8+
let constructor = "triton::createTritonPtrToMemrefPass()";
9+
}
10+
11+
#endif
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#ifndef TRITON_CONVERSION_TRITON_PTR_TO_MEMREF_TRITON_PTR_TO_MEMREF_H
2+
#define TRITON_CONVERSION_TRITON_PTR_TO_MEMREF_TRITON_PTR_TO_MEMREF_H
3+
4+
#include "mlir/Pass/Pass.h"
5+
#include "mlir/Transforms/DialectConversion.h"
6+
7+
#include "triton/Dialect/Triton/IR/Dialect.h"
8+
9+
namespace mlir {
10+
namespace triton {
11+
12+
std::unique_ptr<OperationPass<ModuleOp>> createTritonPtrToMemrefPass();
13+
14+
} // namespace triton
15+
} // namespace mlir
16+
17+
#endif // TRITON_CONVERSION_TRITON_PTR_TO_MEMREF_TRITON_PTR_TO_MEMREF_H

lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ add_subdirectory(TritonToStructured)
44
add_subdirectory(TritonToUnstructured)
55
add_subdirectory(TritonArithToLinalg)
66
add_subdirectory(StructuredToMemref)
7+
add_subdirectory(TritonPtrToMemref)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_triton_library(TritonPtrToMemref
2+
TritonPtrToMemrefPass.cpp
3+
4+
DEPENDS
5+
TritonPtrToMemrefConversionPassIncGen
6+
7+
LINK_LIBS PUBLIC
8+
MLIRArithDialect
9+
MLIRDialectUtils
10+
MLIRIR
11+
MLIRMathDialect
12+
MLIRPass
13+
MLIRTensorDialect
14+
MLIRTransforms
15+
MLIRSupport
16+
MLIRReconcileUnrealizedCasts
17+
TritonIR
18+
)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Copyright (c) Microsoft Corporation.
4+
// Licensed under the MIT license.
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#include "mlir/Dialect/Arith/IR/Arith.h"
9+
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
10+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
11+
#include "mlir/Dialect/SCF/IR/SCF.h"
12+
#include "mlir/IR/Builders.h"
13+
#include "mlir/IR/BuiltinAttributes.h"
14+
#include "mlir/IR/BuiltinOps.h"
15+
#include "mlir/IR/BuiltinTypes.h"
16+
#include "mlir/IR/MLIRContext.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/IR/Types.h"
19+
#include "mlir/IR/Value.h"
20+
#include "mlir/IR/ValueRange.h"
21+
#include "mlir/Pass/PassManager.h"
22+
#include "mlir/Support/LLVM.h"
23+
#include "mlir/Support/LogicalResult.h"
24+
#include "mlir/Transforms/Passes.h"
25+
#include "triton-shared/Analysis/OpFoldResultUtils.h"
26+
#include "triton-shared/AnalysisStructured/PtrAnalysis.h"
27+
#include "triton-shared/Conversion/TritonPtrToMemref/TritonPtrToMemref.h"
28+
#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
29+
30+
#include "triton/Dialect/Triton/IR/Dialect.h"
31+
32+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
33+
#include "triton/Dialect/Triton/IR/Types.h"
34+
35+
#define DEBUG_TYPE "triton-ptr-to-memref"
36+
37+
using namespace mlir;
38+
using namespace triton;
39+
40+
#define GEN_PASS_CLASSES
41+
#include "triton-shared/Conversion/TritonPtrToMemref/Passes.h.inc"
42+
43+
namespace {
44+
45+
class TritonFunctionSignatureConverter : public TypeConverter {
46+
public:
47+
TritonFunctionSignatureConverter() {
48+
// The order of type conversion is important: later ones are tried earlier.
49+
addConversion([](Type type) { return type; });
50+
addConversion([](triton::PointerType ptrType) {
51+
return UnrankedMemRefType::get(ptrType.getPointeeType(),
52+
/*memorySpace=*/0);
53+
});
54+
addConversion([](RankedTensorType tensorType) -> std::optional<Type> {
55+
if (auto ptrType =
56+
dyn_cast<triton::PointerType>(tensorType.getElementType())) {
57+
return MemRefType::get(tensorType.getShape(), ptrType.getPointeeType());
58+
}
59+
return std::nullopt;
60+
});
61+
62+
auto createUnrealizedCast = [&](OpBuilder &builder, Type resultType,
63+
ValueRange inputs,
64+
Location loc) -> std::optional<Value> {
65+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
66+
.getResult(0);
67+
};
68+
addSourceMaterialization(createUnrealizedCast);
69+
addArgumentMaterialization(createUnrealizedCast);
70+
}
71+
};
72+
73+
class TritonPtrToMemrefPass
74+
: public TritonPtrToMemrefBase<TritonPtrToMemrefPass> {
75+
76+
public:
77+
void getDependentDialects(DialectRegistry &registry) const override {
78+
registry
79+
.insert<arith::ArithDialect, math::MathDialect, affine::AffineDialect,
80+
scf::SCFDialect, tensor::TensorDialect, triton::TritonDialect,
81+
tts::TritonStructuredDialect>();
82+
}
83+
84+
void runOnOperation() override {
85+
auto moduleOp = getOperation();
86+
87+
RewritePatternSet patterns(&getContext());
88+
ConversionTarget target(getContext());
89+
TritonFunctionSignatureConverter typeConverter;
90+
91+
// Update function signature and call ops to use memrefs
92+
target.addDynamicallyLegalOp<func::FuncOp, triton::FuncOp>([&](auto op) {
93+
return typeConverter.isSignatureLegal(
94+
cast<FunctionType>(cast<FunctionOpInterface>(op).getFunctionType()));
95+
});
96+
97+
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
98+
return typeConverter.isLegal(op.getResultTypes()) &&
99+
typeConverter.isLegal(op.getOperandTypes());
100+
});
101+
102+
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
103+
patterns, typeConverter);
104+
populateFunctionOpInterfaceTypeConversionPattern<triton::FuncOp>(
105+
patterns, typeConverter);
106+
populateCallOpTypeConversionPattern(patterns, typeConverter);
107+
108+
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
109+
signalPassFailure();
110+
}
111+
112+
PassManager pm(&getContext(), moduleOp.getOperationName());
113+
pm.addPass(createCanonicalizerPass());
114+
pm.addPass(createCSEPass());
115+
if (failed(runPipeline(pm, getOperation()))) {
116+
signalPassFailure();
117+
}
118+
}
119+
};
120+
} // namespace
121+
122+
std::unique_ptr<OperationPass<ModuleOp>> triton::createTritonPtrToMemrefPass() {
123+
return std::make_unique<TritonPtrToMemrefPass>();
124+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: triton-shared-opt --triton-arith-to-linalg --triton-ptr-to-memref %s | FileCheck %s
2+
3+
module {
4+
tt.func @_sum_combine__fp32(%arg0: !tt.ptr<f32>) -> f32{
5+
%0 = arith.constant 42.0 : f32
6+
tt.return %0 : f32
7+
}
8+
tt.func @test(%arg0: !tt.ptr<f32>) -> f32{
9+
%0 = tt.call @_sum_combine__fp32(%arg0) : (!tt.ptr<f32>) -> f32
10+
tt.return %0 : f32
11+
}
12+
}
13+
14+
// CHECK: module {
15+
// CHECK: func.func @_sum_combine__fp32(%arg0: memref<*xf32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) -> f32 {
16+
// CHECK: %cst = arith.constant 4.200000e+01 : f32
17+
// CHECK: return %cst : f32
18+
// CHECK: }
19+
// CHECK: func.func @test(%arg0: memref<*xf32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) -> f32 {
20+
// CHECK: %0 = call @_sum_combine__fp32(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (memref<*xf32>, i32, i32, i32, i32, i32, i32) -> f32
21+
// CHECK: return %0 : f32
22+
// CHECK: }
23+
// CHECK: }
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: triton-shared-opt --triton-ptr-to-memref %s | FileCheck %s
2+
3+
module {
4+
func.func public @add_kernel_01234(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>, %arg3: i32) {
5+
%c1024_i32 = arith.constant 1024 : i32
6+
%0 = tt.get_program_id x : i32
7+
%1 = arith.muli %0, %c1024_i32 : i32
8+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
9+
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
10+
%4 = arith.addi %3, %2 : tensor<1024xi32>
11+
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32>
12+
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
13+
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
14+
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
15+
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>
16+
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
17+
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
18+
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>
19+
%13 = arith.addf %9, %12 : tensor<1024xf32>
20+
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
21+
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
22+
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>
23+
return
24+
}
25+
}
26+
27+
// CHECK: func.func public @add_kernel_01234(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>, %arg3: i32)

0 commit comments

Comments
 (0)