Skip to content

Commit 986cea8

Browse files
authored
Introduce triton-to-unstructured pass (#210)
This PR introduces the `triton-to-unstructured` pass which is the first step towards allowing triton-shared to compile pointer sequences that cannot be analyzed by `triton-to-structured` (gather / scatter). This pass attempts to lower all loads and stores of unstructured pointers to tts.gather or tts.scatter that take a single base, a tensor of offsets, an optional tensor of mask values, and a default value in case of load. In addition, all pointer-producing ops will be eliminated and replaced by offset-producing ops. tts.gather and tts.scatter will use the pointer directly from the kernel arguments as opposed to pointer produced by ops such as tt.addptr and tt.splat. Example: ```mlir module { tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} { %cst = arith.constant dense<5> : tensor<64xi32> %cst_0 = arith.constant dense<10> : tensor<64xi32> %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> %1 = arith.divsi %0, %cst_0 : tensor<64xi32> %2 = arith.addi %1, %cst : tensor<64xi32> %3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>> %4 = tt.addptr %3, %2 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> %5 = tt.load %4 : tensor<64x!tt.ptr<f32>> %6 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>> %7 = tt.addptr %6, %0 : tensor<64x!tt.ptr<f32>>, tensor<64xi32> tt.store %7, %5 : tensor<64x!tt.ptr<f32>> tt.return } } ``` becomes ```mlir module { tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} { %cst = arith.constant dense<5> : tensor<64xi32> %cst_0 = arith.constant dense<10> : tensor<64xi32> %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> %1 = arith.divsi %0, %cst_0 : tensor<64xi32> %2 = arith.addi %1, %cst : tensor<64xi32> %3 = tts.gather %arg0[%2] : (<f32>, tensor<64xi32>) -> tensor<64xf32> tts.scatter %3 into %arg1[%0] : tensor<64xf32> into (<f32>, tensor<64xi32>) tt.return } } ``` Current 
assumptions and limitations: - For simplicity, the pass assumes that gather / scatter operations load / store from / to a single base with a tensor of random offsets. As a result, the following triton program would not work: ```python @triton.jit def gather_simple(in0, in1, out0): offs = tl.arange(0, 8) in0_ptrs = in0 + offs in1_ptrs = in1 + offs ptrs = tl.cat(in0_ptrs, in1_ptrs, can_reorder=True) c = tl.load(ptrs) out_offs = tl.arange(0, 16) tl.store(out0 + out_offs, c) ``` In the above program, `ptrs` contains 2 bases: `in0` and `in1` after the `cat` operation. For more details on the algorithm, see the `TritonToUnstructuredPass.cpp` file. # Future work Future work may include scaling the algorithm to support multiple bases -- one possible solution is to let tts.gather and tts.scatter take in an additional tensor of base pointers corresponding to the tensor of offsets. But because we do not want pointer-producing ops to be present after this pass, we can use a tensor of index where each element indicates the index of the pointer argument to be used. The drawback is a gather or scatter operation now needs one extra lookup to get the base, which will affect performance. 
--- # Intended lowering pipeline - triton-to-structured (no changes): - analyzes structured addptr sequences - introduces `tts.make_tptr %ptr_arg with offsets and strides` - introduces `tts.load` and `tts.store` - leaves unstructured addptr sequences and their corresponding `tt.load` and `tt.store` intact - triton-to-unstructured (#210): - introduces `tts.gather` and `tts.scatter` - removes all pointer-producing ops such as `tt.addptr` and `tt.splat` and replaces them with offset-producing ops - structured-to-memref (#217): - currently converts everything to memref including scalar addptr and kernel arguments - will change to just convert ops in the `tts` dialect to `memref` with the exception of `tts.gather` and `tts.scatter` - unstructured-to-memref (#216): - converts the remaining unstructured `tts.gather`, `tts.scatter` into memref - triton-ptr-to-memref (#211): - converts kernel arguments with pointer type to memref
1 parent 8b9f5dd commit 986cea8

28 files changed

+1809
-89
lines changed

include/triton-shared/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ add_subdirectory(TritonToLinalg)
22
add_subdirectory(TritonToLinalgExperimental)
33
add_subdirectory(TritonToStructured)
44
add_subdirectory(TritonArithToLinalg)
5+
add_subdirectory(TritonToUnstructured)
56
add_subdirectory(StructuredToMemref)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToUnstructured)
3+
add_public_tablegen_target(TritonToUnstructuredConversionPassIncGen)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES_H
2+
#define TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES_H
3+
4+
#include "triton-shared/Conversion/TritonToUnstructured/TritonToUnstructured.h"
5+
6+
namespace mlir {
7+
namespace triton {
8+
9+
#define GEN_PASS_REGISTRATION
10+
#include "triton-shared/Conversion/TritonToUnstructured/Passes.h.inc"
11+
12+
} // namespace triton
13+
} // namespace mlir
14+
15+
#endif
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES
2+
#define TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
def TritonToUnstructured : Pass<"triton-to-unstructured", "mlir::ModuleOp"> {
7+
let summary = "Transforms tt.addptr ops into offset accumulation ops";
8+
let constructor = "triton::createTritonToUnstructuredPass()";
9+
let options = [
10+
Option<"offsetBitWidth", "offset-bit-width", "size_t", /*default*/"32",
11+
"Bitwidth used for the starting offset of each pointer">
12+
];
13+
}
14+
15+
#endif
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#ifndef TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H
2+
#define TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H
3+
4+
#include "mlir/Pass/Pass.h"
5+
#include "mlir/Transforms/DialectConversion.h"
6+
7+
#include "triton/Dialect/Triton/IR/Dialect.h"
8+
9+
namespace mlir {
10+
namespace triton {
11+
12+
std::unique_ptr<OperationPass<ModuleOp>> createTritonToUnstructuredPass();
13+
14+
} // namespace triton
15+
} // namespace mlir
16+
17+
#endif // TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H

include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
#ifndef MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_
22
#define MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_
33

4-
#include "mlir/IR/BuiltinOps.h"
5-
#include "mlir/IR/BuiltinTypes.h"
64
#include "mlir/IR/Dialect.h"
75
#include "mlir/IR/MLIRContext.h"
86
#include "mlir/IR/OpDefinition.h"
9-
#include "mlir/IR/SymbolTable.h"
10-
#include "mlir/IR/TypeSupport.h"
11-
#include "mlir/IR/Types.h"
12-
#include "mlir/Interfaces/SideEffectInterfaces.h"
7+
138
#include "triton/Dialect/Triton/IR/Dialect.h"
149

15-
#include "mlir/IR/Dialect.h"
10+
namespace mlir {
11+
namespace tts {
12+
namespace utils {
13+
mlir::Value getScalarValue(mlir::Value operand, mlir::Location loc,
14+
mlir::OpBuilder &builder);
15+
}
16+
} // namespace tts
17+
} // namespace mlir
1618

1719
//===----------------------------------------------------------------------===//
1820
// TritonStructured Operations

include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,6 @@ def TTS_MakeTensorPtrOp
120120
//let hasCanonicalizer = 1;
121121
}
122122

123-
// SameVariadicResultSize
124-
// AttrSizedResultSegments
125123
def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSegments, Pure]> {
126124
let summary = "Placeholder for the structured pointer states computed during PtrAnalysis.";
127125
let description = "Used to pass the offsets and strides to scf.for op to simplify IR rewrites.";
@@ -145,6 +143,51 @@ def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSe
145143
let hasVerifier = 1;
146144
}
147145

146+
def TTS_GatherOp : TTS_Op<"gather", [
147+
MemoryEffects<[MemRead]>,
148+
AttrSizedOperandSegments,
149+
OptionalTypesMatchWith<"mask type matches ptr type", "offset", "mask", "triton::getI1SameShape($_self)">,
150+
OptionalTypesMatchWith<"other matches ptr type", "ptr", "other", "triton::getPointeeType($_self)">
151+
]> {
152+
let summary = "optionally load data from in memory to fill a portion of the tensor";
153+
154+
let arguments = (
155+
ins
156+
TT_Ptr:$ptr,
157+
TT_IntLike:$offset,
158+
Optional<TT_BoolLike>:$mask,
159+
Optional<TT_Type>:$other
160+
);
161+
162+
let results = (outs TT_Type:$result);
163+
164+
let assemblyFormat = [{
165+
$ptr `[` $offset `]` (`mask` `=` $mask^)? (`default` `=` $other^)?
166+
attr-dict `:` `(` type($ptr) `,` type($offset) `)` `->` type($result)
167+
}];
168+
}
169+
170+
def TTS_ScatterOp : TTS_Op<"scatter", [
171+
MemoryEffects<[MemWrite]>,
172+
OptionalTypesMatchWith<"mask type matches offset type", "offset", "mask",
173+
"triton::getI1SameShape($_self)">
174+
]> {
175+
let summary = "optionally store data from in memory to fill a portion of the tensor";
176+
177+
let arguments = (
178+
ins
179+
TT_Ptr:$ptr,
180+
TT_IntLike:$offset,
181+
TT_Type:$value,
182+
Optional<TT_BoolLike>:$mask
183+
);
184+
185+
let assemblyFormat = [{
186+
$value `into` $ptr `[` $offset `]` (`mask` `=` $mask^)?
187+
attr-dict `:` type($value) `into` ` ` `(` type($ptr) `,` type($offset) `)`
188+
}];
189+
}
190+
148191
def TTS_LoadOp : TTS_Op<"load", [
149192
MemoryEffects<[MemRead]>,
150193
AttrSizedOperandSegments
@@ -170,7 +213,7 @@ def TTS_LoadOp : TTS_Op<"load", [
170213
}
171214

172215
bool hasMask() {
173-
return !getStaticMaskDims().empty();
216+
return !getMixedMaskDims().empty();
174217
}
175218
}];
176219

@@ -182,7 +225,7 @@ def TTS_LoadOp : TTS_Op<"load", [
182225
def TTS_StoreOp : TTS_Op<"store", [
183226
MemoryEffects<[MemWrite]>
184227
]> {
185-
let summary = "optionally load data from in memory to fill a portion of the tensor";
228+
let summary = "optionally store data from in memory to fill a portion of the tensor";
186229

187230
let arguments = (ins TT_PtrLike:$ptr,
188231
TT_Tensor:$value,
@@ -201,7 +244,7 @@ def TTS_StoreOp : TTS_Op<"store", [
201244
}
202245

203246
bool hasMask() {
204-
return !getStaticMaskDims().empty();
247+
return !getMixedMaskDims().empty();
205248
}
206249
}];
207250

lib/AnalysisStructured/PtrAnalysis.cpp

Lines changed: 1 addition & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@
88
#include "triton-shared/AnalysisStructured/PtrAnalysis.h"
99
#include "mlir/Dialect/Arith/IR/Arith.h"
1010
#include "mlir/Dialect/SCF/IR/SCF.h"
11-
#include "mlir/Dialect/Utils/StaticValueUtils.h"
12-
#include "mlir/IR/BuiltinOps.h"
1311
#include "mlir/IR/BuiltinTypes.h"
1412
#include "mlir/IR/Value.h"
15-
#include "mlir/IR/ValueRange.h"
1613
#include "mlir/IR/Visitors.h"
1714
#include "mlir/Support/LLVM.h"
1815
#include "mlir/Support/LogicalResult.h"
@@ -33,7 +30,6 @@
3330
#include "llvm/Support/LogicalResult.h"
3431
#include <cassert>
3532
#include <cstddef>
36-
#include <functional>
3733
#include <optional>
3834
#include <queue>
3935
#include <string>
@@ -42,74 +38,6 @@
4238

4339
namespace mlir {
4440

45-
// Extract a scalar value from v.
46-
// If v is a scalar, return that directly. Otherwise, parse through operations
47-
// (currently only support splat, sitofp, and truncf) that produce it to
48-
// extract the underlying scalar value. We then reconstruct the chain of
49-
// operations that can produce this constant with the original type. If no
50-
// scalar value can be extracted, a nullptr is returned.
51-
static Value getScalarValue(Value operand, Location loc, OpBuilder &builder) {
52-
SmallVector<Operation *> ops;
53-
54-
auto reconstructScalarValue = [&](Value src) {
55-
for (auto op = ops.rbegin(); op != ops.rend(); ++op) {
56-
src = TypeSwitch<Operation *, Value>(*op)
57-
.Case<arith::SIToFPOp>([&](Operation *op) {
58-
auto resType = op->getResults()[0].getType();
59-
if (auto shapedType = dyn_cast<ShapedType>(resType)) {
60-
resType = shapedType.getElementType();
61-
}
62-
return builder.create<arith::SIToFPOp>(loc, resType, src);
63-
})
64-
.Case<arith::TruncFOp>([&](Operation *op) {
65-
auto resType = op->getResults()[0].getType();
66-
if (auto shapedType = dyn_cast<ShapedType>(resType)) {
67-
resType = shapedType.getElementType();
68-
}
69-
return builder.create<arith::TruncFOp>(loc, resType, src);
70-
})
71-
.Default([](Operation *op) {
72-
llvm_unreachable("unsupported op in generating ");
73-
return nullptr;
74-
});
75-
}
76-
return src;
77-
};
78-
79-
while (true) {
80-
if (!dyn_cast<ShapedType>(operand.getType())) {
81-
return reconstructScalarValue(operand);
82-
} else if (auto op = operand.getDefiningOp<arith::ConstantOp>()) {
83-
if (auto attr = dyn_cast<DenseElementsAttr>(op.getValue())) {
84-
if (!attr.isSplat()) {
85-
InFlightDiagnostic diag = emitError(loc)
86-
<< "other value used in masked load "
87-
"produced by unsupported instruction";
88-
return nullptr;
89-
}
90-
auto elemValue = attr.getSplatValue<Attribute>();
91-
auto constOp = arith::ConstantOp::materialize(
92-
builder, elemValue, attr.getElementType(), op.getLoc());
93-
return reconstructScalarValue(constOp.getResult());
94-
}
95-
} else if (auto op = operand.getDefiningOp<triton::SplatOp>()) {
96-
operand = op.getSrc();
97-
} else if (auto op = operand.getDefiningOp<arith::SIToFPOp>()) {
98-
ops.push_back(op.getOperation());
99-
operand = op.getIn();
100-
} else if (auto op = operand.getDefiningOp<arith::TruncFOp>()) {
101-
ops.push_back(op.getOperation());
102-
operand = op.getIn();
103-
} else {
104-
InFlightDiagnostic diag = emitError(loc)
105-
<< "other value used in masked load produced "
106-
"by unsupported instruction";
107-
return nullptr;
108-
}
109-
}
110-
return nullptr;
111-
}
112-
11341
namespace tts {
11442

11543
int32_t PtrState::getRank() const {
@@ -1159,7 +1087,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op,
11591087
if (other) {
11601088
assert(mask && "other value used while no masks are specified");
11611089

1162-
scalarOther = getScalarValue(other, loc, builder);
1090+
scalarOther = utils::getScalarValue(other, loc, builder);
11631091
if (!scalarOther) {
11641092
op->emitRemark("other value used in masked load produced by "
11651093
"unsupported instruction");

lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_subdirectory(TritonToLinalg)
22
add_subdirectory(TritonToLinalgExperimental)
33
add_subdirectory(TritonToStructured)
4+
add_subdirectory(TritonToUnstructured)
45
add_subdirectory(TritonArithToLinalg)
56
add_subdirectory(StructuredToMemref)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
add_triton_library(TritonToUnstructured
2+
TritonToUnstructuredPass.cpp
3+
4+
DEPENDS
5+
TritonStructuredTableGen
6+
TritonToUnstructuredConversionPassIncGen
7+
8+
LINK_LIBS PUBLIC
9+
MLIRArithDialect
10+
MLIRDialectUtils
11+
MLIRIR
12+
MLIRMathDialect
13+
MLIRPass
14+
MLIRTensorDialect
15+
MLIRTransforms
16+
MLIRSupport
17+
MLIRReconcileUnrealizedCasts
18+
TritonIR
19+
TritonTransforms
20+
TritonSharedAnalysisStructured
21+
TritonStructuredIR
22+
)

0 commit comments

Comments
 (0)