Commit 99b5e29

[TensorDescriptor] Fallback to tl.load/store on hardware without TMA (#6753)
This implements a pass for converting TMA loads/stores into legacy loads/stores, which is required for supporting tensor descriptors on hardware that doesn't directly support them. This does not implement:

* Host-side tensor descriptors - I'll submit these in a follow-up PR.
* Descriptor reduction operations.
* Interop for unsupported tensor descriptors on devices which support tensor descriptors.

This updates the (old) CUDA and HIP lowering to use the new pass. Lit tests have been added for the pass, and the CUDA tensor descriptor tests that work on hardware have been moved to the language folder since they are now supported on other hardware. The HIP lowering is untested as I don't have access to an AMD card; I have tested the CUDA lowering on an A100 machine.
1 parent 26d7722 commit 99b5e29
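
As a rough orientation, the sketch below shows how a backend pipeline could opt into this fallback; it is not code from this commit. The factory name createTritonRewriteTensorDescriptorToPointer() is assumed to be the TableGen-generated constructor for the pass added in Passes.td below, and hasTMASupport stands in for whatever capability check the backend actually uses.

#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/Triton/Transforms/Passes.h"

// Sketch only: the factory name and the capability flag are assumptions.
static void buildDescriptorLowering(mlir::PassManager &pm, bool hasTMASupport) {
  if (!hasTMASupport) {
    // Rewrite tt.make_tensor_descriptor plus descriptor loads/stores into
    // pointer arithmetic and masked tt.load/tt.store before further lowering.
    pm.addPass(mlir::triton::createTritonRewriteTensorDescriptorToPointer());
  }
  // ...the rest of the backend pipeline would follow here.
}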

File tree

13 files changed: +1990 -1277 lines changed
include/triton/Dialect/Triton/Transforms/ArithTypeConversion.h

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::triton {

/**
 * @brief Provides helper patterns for converting arith operations using a type
 * converter.
 *
 * Note that as of the time of writing this isn't provided in upstream MLIR.
 */
void populateArithTypeConversions(const TypeConverter &converter,
                                  RewritePatternSet &patterns);

} // namespace mlir::triton

#endif // TRITON_DIALECT_TRITON_TRANSFORMS_ARITH_TYPE_CONVERSION_H_
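
For context, here is a minimal sketch of how helper patterns like these are typically driven: a TypeConverter with a 1:N rule, the patterns from this header, and a partial conversion. MyDescriptorType is a hypothetical placeholder and the legality setup is illustrative, not the pass's actual configuration; only populateArithTypeConversions comes from this commit.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h"

static mlir::LogicalResult convertDescriptorTypes(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();

  mlir::TypeConverter converter;
  // Types we don't care about are passed through unchanged.
  converter.addConversion([](mlir::Type type) { return type; });
  // A 1:N rule for the type being erased would be added here, e.g. expanding
  // a descriptor into a base pointer plus shape/stride scalars (sketched
  // only; MyDescriptorType is not a real Triton type):
  //   converter.addConversion([](MyDescriptorType type,
  //                              llvm::SmallVectorImpl<mlir::Type> &out) {
  //     ...
  //     return mlir::success();
  //   });

  mlir::RewritePatternSet patterns(ctx);
  mlir::triton::populateArithTypeConversions(converter, patterns);

  mlir::ConversionTarget target(*ctx);
  // arith.select stays legal unless one of its values uses an illegal type.
  target.addDynamicallyLegalOp<mlir::arith::SelectOp>(
      [&](mlir::arith::SelectOp op) { return converter.isLegal(op); });

  return mlir::applyPartialConversion(module.getOperation(), target,
                                      std::move(patterns));
}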
include/triton/Dialect/Triton/Transforms/FunctionTypeConversion.h

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
#ifndef TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
#define TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::triton {

/**
 * @brief Provides helper patterns for converting triton function operations
 * using a type converter.
 *
 * Note we cannot use upstream passes for this because they are unaware of
 * tt.call and tt.return.
 */
void populateFunctionTypeConversions(const TypeConverter &converter,
                                     RewritePatternSet &patterns);

} // namespace mlir::triton

#endif // TRITON_DIALECT_TRITON_TRANSFORMS_FUNCTION_TYPE_CONVERSION_H_
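
Along the same lines, a sketch (again illustrative, not taken from the pass itself) of how these patterns could be wired into a ConversionTarget so that tt.func, tt.call, and tt.return are converted exactly when their signatures still mention an unconverted type.

#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/Transforms/FunctionTypeConversion.h"

static void configureFunctionConversion(const mlir::TypeConverter &converter,
                                        mlir::RewritePatternSet &patterns,
                                        mlir::ConversionTarget &target) {
  mlir::triton::populateFunctionTypeConversions(converter, patterns);

  // tt.func is legal once its signature only mentions converted types.
  target.addDynamicallyLegalOp<mlir::triton::FuncOp>(
      [&converter](mlir::triton::FuncOp op) {
        return converter.isSignatureLegal(op.getFunctionType());
      });
  // tt.call and tt.return are legal once their operand/result types are.
  target.addDynamicallyLegalOp<mlir::triton::CallOp, mlir::triton::ReturnOp>(
      [&converter](mlir::Operation *op) { return converter.isLegal(op); });
}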

include/triton/Dialect/Triton/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
@@ -46,6 +46,17 @@ def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer"
   let dependentDialects = ["mlir::triton::TritonDialect"];
 }
 
+def TritonRewriteTensorDescriptorToPointer : Pass</*cli-arg*/"triton-rewrite-tensor-descriptor-to-pointer", /*Op*/"mlir::ModuleOp"> {
+  let summary = "Rewrite loads/stores of tensor descriptors into pointer loads/stores";
+  let description = [{
+    This pass rewrites all load/store semantics initiated by a `tt.make_tensor_descriptor` into pointer semantics. After
+    this pass, `tt.make_tensor_descriptor` will disappear, and the pass generates the logic to compute the pointer/mask/other
+    operands for each load/store.
+  }];
+
+  let dependentDialects = ["mlir::triton::TritonDialect"];
+}
+
 def TritonLoopUnroll : Pass</*cli-arg*/"triton-loop-unroll", /*Op*/"mlir::ModuleOp"> {
   let summary = "Loop unroller";
   let description = [{
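
To make the description concrete, the scalar sketch below spells out the kind of per-element pointer/mask computation the pass description refers to, for a 2D descriptor with explicit strides. This is plain C++ for illustration, not the MLIR the pass emits; the resulting offset is added to the descriptor's base pointer and the predicate becomes the mask of the legacy load/store.

#include <cstdint>

struct ElementAccess {
  int64_t offset;  // element offset relative to the descriptor's base pointer
  bool inBounds;   // mask operand of the fallback tt.load / tt.store
};

// For a BLOCK_M x BLOCK_N tile loaded at block origin (off0, off1), element
// (i, j) maps to one address and one in-bounds predicate; out-of-bounds lanes
// are masked instead of relying on TMA's hardware bounds handling.
ElementAccess describeElement(int64_t off0, int64_t off1, // tile origin
                              int64_t i, int64_t j,       // position in tile
                              int64_t shape0, int64_t shape1,
                              int64_t stride0, int64_t stride1) {
  int64_t row = off0 + i;
  int64_t col = off1 + j;
  return {row * stride0 + col * stride1, row < shape0 && col < shape1};
}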
lib/Dialect/Triton/Transforms/ArithTypeConversion.cpp

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"

namespace {

struct RewriteArithSelectOp : mlir::OpConversionPattern<mlir::arith::SelectOp> {
  using mlir::OpConversionPattern<mlir::arith::SelectOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::arith::SelectOp op, OneToNOpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // Note we're replacing the select op with an if op because we are
    // converting one value into many values.
    auto newIf = rewriter.create<mlir::scf::IfOp>(
        op.getLoc(), mlir::TypeRange(adaptor.getTrueValue()), op.getCondition(),
        true);
    // We set the attributes from the op in case the op has any additional
    // attributes.
    newIf->setAttrs(op->getAttrs());

    {
      mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(newIf.thenBlock());
      rewriter.create<mlir::scf::YieldOp>(op->getLoc(), adaptor.getTrueValue());
      rewriter.setInsertionPointToStart(newIf.elseBlock());
      rewriter.create<mlir::scf::YieldOp>(op->getLoc(),
                                          adaptor.getFalseValue());
    }

    // Replace the old operation results
    rewriter.replaceOpWithMultiple(op, {newIf->getResults()});

    return mlir::success();
  }
};

} // namespace
namespace mlir::triton {

void populateArithTypeConversions(const TypeConverter &converter,
                                  RewritePatternSet &patterns) {
  patterns.add<RewriteArithSelectOp>(converter, patterns.getContext());
}

} // namespace mlir::triton
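
A plausible way to register this pattern, not taken from this commit: since scf.for/scf.if/scf.while can also carry the converted type through region arguments and results, the arith pattern above would normally sit next to MLIR's upstream SCF structural type conversions.

#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "triton/Dialect/Triton/Transforms/ArithTypeConversion.h"

// Sketch only: combines the new arith helper with the upstream SCF helper.
static void collectTypeConversionPatterns(mlir::TypeConverter &converter,
                                          mlir::RewritePatternSet &patterns,
                                          mlir::ConversionTarget &target) {
  mlir::triton::populateArithTypeConversions(converter, patterns);
  // Upstream helper that rewrites scf.for/scf.if/scf.while signatures and
  // marks them legal under the converter.
  mlir::scf::populateSCFStructuralTypeConversionsAndLegality(converter,
                                                             patterns, target);
}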

lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,9 @@ add_triton_library(TritonTransforms
   LoopUnroll.cpp
   ReorderBroadcast.cpp
   RewriteTensorPointer.cpp
+  RewriteTensorDescriptorToPointer.cpp
+  ArithTypeConversion.cpp
+  FunctionTypeConversion.cpp
 
   DEPENDS
   TritonTransformsIncGen
lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
#include "triton/Dialect/Triton/Transforms/FunctionTypeConversion.h"

#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#include <cstdlib>

namespace mlir::triton {

namespace {

SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> ret;
  for (const auto &vs : values) {
    llvm::append_range(ret, vs);
  }
  return ret;
}

struct CallOpConversion : public OpConversionPattern<CallOp> {
  using OpConversionPattern<CallOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<std::size_t> resultReplacementGrouping;
    llvm::SmallVector<Type> convertedResults;

    for (auto type : callOp->getResultTypes()) {
      const auto oldNumFlattenedResults = convertedResults.size();
      if (failed(getTypeConverter()->convertTypes(type, convertedResults))) {
        return failure();
      }
      resultReplacementGrouping.push_back(convertedResults.size() -
                                          oldNumFlattenedResults);
    }

    auto newCallOp = rewriter.create<CallOp>(
        callOp->getLoc(), callOp.getCallee(), convertedResults,
        flattenValues(adaptor.getOperands()));
    // Preserve any additional attributes that may have been set on the op
    newCallOp->setAttrs(callOp->getAttrs());

    SmallVector<ValueRange> replacements;
    std::size_t offset = 0;
    for (auto groupSize : resultReplacementGrouping) {
      replacements.push_back(newCallOp->getResults().slice(offset, groupSize));
      offset += groupSize;
    }

    rewriter.replaceOpWithMultiple(callOp, replacements);
    return success();
  }
};

struct ReturnOpConversion : public OpConversionPattern<ReturnOp> {
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp returnOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto newReturnOp = rewriter.create<ReturnOp>(
        returnOp->getLoc(), flattenValues(adaptor.getOperands()));
    // Preserve any additional attributes that may have been set on the op
    newReturnOp->setAttrs(returnOp->getAttrs());

    rewriter.replaceOp(returnOp, newReturnOp);
    return success();
  }
};

} // namespace

void populateFunctionTypeConversions(const TypeConverter &converter,
                                     RewritePatternSet &patterns) {
  mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::triton::FuncOp>(
      patterns, converter);
  patterns.add<CallOpConversion, ReturnOpConversion>(converter,
                                                     patterns.getContext());
}

} // namespace mlir::triton
