
Commit 6ce6d5b

Merge commit 'f47cc3eaaa11cf87ffd93127a5d57eed907bdcd5'
2 parents b0ddc4b + f47cc3e

28 files changed, +518 -297 lines changed

README.md

Lines changed: 8 additions & 2 deletions
@@ -232,8 +232,14 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
 - `MLIR_ENABLE_TIMING` dumps the timing information for each MLIR pass.
 - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
 - `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
-- `MLIR_ENABLE_DIAGNOSTICS` enables dumping the stack trace and the related IR operation of diagnostics (e.g., errors and warnings).
-- `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.
+- `MLIR_ENABLE_DIAGNOSTICS=<comma-separated>` controls diagnostic emission in MLIR.
+  Options are: `warnings`, `remarks`, `stacktraces`, `operations`.
+  Use comma-separated values to customize output. For example,
+  `MLIR_ENABLE_DIAGNOSTICS=remarks,operations` enables remarks and IR operations,
+  while `MLIR_ENABLE_DIAGNOSTICS=warnings,stacktraces` enables warnings with
+  stacktraces. By default, only errors are shown. Setting `warnings` includes
+  errors and warnings; `remarks` includes errors, warnings, and remarks.
+- `MLIR_ENABLE_REMARK` is deprecated. Please use `MLIR_ENABLE_DIAGNOSTICS=remarks`.
 - `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn.
 - `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1.
 - `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage.
include/triton/Conversion/TritonGPUToLLVM/FMADotUtility.h

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+#ifndef TRITON_CONVERSION_FMA_DOT_UTILITY_H
+#define TRITON_CONVERSION_FMA_DOT_UTILITY_H
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir::triton::gpu {
+
+/// Abstract interface for scalar multiplication of Value vectors.
+///
+/// Enables generation of hardware-specific code in different backends.
+class FMAVectorMultiplier {
+public:
+  /// \returns the scalar product of two arrays, plus c: a·b + c
+  virtual Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
+                                Value c) = 0;
+
+  virtual ~FMAVectorMultiplier() = default;
+};
+
+/// Implements a framework for FMA dot conversion to llvm.
+///
+/// This function implements the architecture-independent part of FMA dot
+/// conversion and calls the "multiplier" object, which is defined by the
+/// caller and implements the architecture-dependent part of the conversion.
+LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
+                                      const LLVMTypeConverter *typeConverter,
+                                      ConversionPatternRewriter &rewriter,
+                                      FMAVectorMultiplier &multiplier);
+
+} // namespace mlir::triton::gpu
+
+#endif // TRITON_CONVERSION_FMA_DOT_UTILITY_H
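A backend adopts this interface by subclassing FMAVectorMultiplier and emitting its own multiply-add sequence. The sketch below is illustrative only, not part of this commit: `MyPackedFmaOp` is a hypothetical backend op standing in for a 2-wide packed FMA instruction.

// Sketch only: a hardware-specific multiplier for a backend with a packed
// 2-wide FMA. MyPackedFmaOp is an assumed op, not an existing one.
class PackedFMAVectorMultiplier : public FMAVectorMultiplier {
  OpBuilder &builder;
  Location loc;

public:
  PackedFMAVectorMultiplier(OpBuilder &builder, Location loc)
      : builder(builder), loc(loc) {}

  Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
                        Value c) override {
    assert(a.size() == b.size());
    Value accum = c;
    size_t k = 0;
    // Consume element pairs with the (assumed) packed multiply-add.
    for (; k + 1 < a.size(); k += 2)
      accum = builder.create<MyPackedFmaOp>(loc, a[k], b[k], a[k + 1],
                                            b[k + 1], accum);
    // Scalar tail for an odd element count.
    for (; k < a.size(); ++k)
      accum = builder.create<LLVM::FMulAddOp>(loc, a[k], b[k], accum);
    return accum;
  }
};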

include/triton/Dialect/Triton/IR/Traits.h

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Support/LogicalResult.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 
@@ -27,7 +28,7 @@ LogicalResult verifyTensorLayouts(Operation *op);
 
 LogicalResult verifySameOperandsEncoding(Operation *op,
                                          bool allowTensorPointerType = false);
-
+LogicalResult verifyEquivalentType(Type typeA, Type typeB);
 LogicalResult
 verifySameOperandsAndResultEncoding(Operation *op,
                                     bool allowTensorPointerType = false);

include/triton/Dialect/Triton/IR/TritonInterfaces.td

Lines changed: 14 additions & 0 deletions
@@ -2,6 +2,7 @@
 #define TRITON_INTERFACES
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 
 def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
 def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
 
@@ -13,4 +14,17 @@ def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAn
 def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">;
 def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">;
 
+// A trait equivalent to InferTypeOpAdaptor, but that checks for structural
+// equivalence of the layouts of the result rather than just layout equality.
+def InferTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{
+  static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
+    if (lhs.size() != rhs.size())
+      return false;
+    return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) {
+      auto [lhs, rhs] = tup;
+      return succeeded(OpTrait::impl::verifyEquivalentType(lhs, rhs));
+    });
+  }
+}]>;
+
 #endif // TRITON_INTERFACES

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 1 deletion
@@ -539,7 +539,7 @@ def TT_SplitOp : TT_Op<"split", [
 
 def TT_TransOp : TT_Op<"trans", [Pure,
                                  TransposeOpInterface,
-                                 InferTypeOpAdaptorWithIsCompatible,
+                                 InferTypeOpWithLayoutEquivalence,
                                  SameOperandsAndResultElementType]> {
 
   let summary = "rearrange the dimensions of a tensor";

lib/Analysis/Allocation.cpp

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
 
   std::tie(scratchConfig.inVec, scratchConfig.outVec) =
       getScratchCvtInOutVecLengths(srcTy, dstTy);
+  // We can't write a longer vector than the shape of shared memory.
+  // This shape might be smaller than the tensor shape in case we decided to
+  // do the conversion in multiple iterations.
+  unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
+  scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
+  scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
 
   // No padding is required if the tensor is 1-D, or if all dimensions except
   // the first accessed dimension have a size of 1.
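The clamp is easiest to see with concrete numbers. Below is a minimal standalone sketch (values assumed for illustration, not taken from this commit): the conversion is split into iterations whose per-iteration scratch shape is smaller than the tensor, so a vector width derived from the full tensor must not exceed the contiguous scratch dimension.

// Standalone illustration of the vector-width clamp above (assumed values).
#include <algorithm>
#include <cstdio>

int main() {
  unsigned repShape[] = {32, 8}; // per-iteration shared-memory shape
  unsigned order[] = {1, 0};     // order[0] indexes the contiguous dimension
  unsigned inVec = 16;           // width computed from the full tensor shape

  unsigned contiguousShapeDim = repShape[order[0]]; // == 8
  inVec = std::min(inVec, contiguousShapeDim);      // clamped from 16 to 8
  std::printf("inVec = %u\n", inVec);
  return 0;
}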

lib/Conversion/TritonGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 add_triton_library(TritonGPUToLLVM
   ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp
   DotOpToLLVM/FMA.cpp
+  DotOpToLLVM/FMADotUtility.cpp
   AllocateSharedMemory.cpp
   AssertOpToLLVM.cpp
   ControlFlowOpToLLVM.cpp
lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMA.cpp

Lines changed: 23 additions & 129 deletions
@@ -1,144 +1,38 @@
-#include "mlir/Support/LLVM.h"
+#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
-#include "triton/Dialect/TritonGPU/IR/Dialect.h"
-#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 
 using namespace mlir;
 using namespace mlir::triton;
 using namespace ::mlir::triton::gpu;
 
-using ::mlir::LLVM::linearize;
-using ::mlir::triton::gpu::expandMatrixOrderWithBatch;
-using ::mlir::triton::gpu::expandMatrixShapeWithBatch;
-using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::getSizePerThread;
-
-/// \brief spatial position of repetition and register of a given value
-struct OperandValueKey {
-  unsigned bRepIdx, nonKRepIdx;
-  unsigned bIdx, nonKIdx, kIdx;
-
-  bool operator==(const OperandValueKey &other) const {
-    return (bRepIdx == other.bRepIdx && nonKRepIdx == other.nonKRepIdx &&
-            bIdx == other.bIdx && nonKIdx == other.nonKIdx &&
-            kIdx == other.kIdx);
-  }
-};
-
-template <> struct std::hash<OperandValueKey> {
-  std::size_t operator()(const OperandValueKey &k) const {
-    return llvm::hash_combine(k.bRepIdx, k.nonKRepIdx, k.bIdx, k.nonKIdx,
-                              k.kIdx);
+namespace {
+class GenericFMAVectorMultiplier : public FMAVectorMultiplier {
+  OpBuilder &builder;
+  Location loc;
+
+public:
+  GenericFMAVectorMultiplier(OpBuilder &builder, Location loc)
+      : builder(builder), loc(loc) {}
+
+  Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
+                        Value c) override {
+    auto K = a.size();
+    assert(b.size() == K);
+    Value accum = c;
+    for (auto [aElem, bElem] : llvm::zip(a, b))
+      accum = builder.create<LLVM::FMulAddOp>(loc, aElem, bElem, accum);
+    return accum;
   }
 };
 
-using ValueTableFMA = std::unordered_map<OperandValueKey, Value>;
-
-static ValueTableFMA getValueTableFromStructFMA(
-    Value val, ArrayRef<unsigned> perRepShape, ArrayRef<unsigned> repetitions,
-    unsigned kDim, unsigned nonKDim, ConversionPatternRewriter &rewriter,
-    Location loc, ArrayRef<unsigned> inRepOrder, ArrayRef<unsigned> repOrder) {
-  ValueTableFMA res;
-  auto elems = unpackLLElements(loc, val, rewriter);
-  assert(perRepShape.size() == 3);
-  auto numElemsRep = product(perRepShape);
-  assert(elems.size() == numElemsRep * product(repetitions));
-  assert(kDim == 1 || kDim == 2);
-  assert(nonKDim == 1 || nonKDim == 2);
-  const unsigned bDim = 0;
+} // namespace
 
-  for (unsigned idx = 0; idx < elems.size(); ++idx) {
-    auto inRepLinearIdx = idx % numElemsRep;
-    auto repLinearIdx = idx / numElemsRep;
-    auto inRepSpatialIdx =
-        mlir::LLVM::delinearize(inRepLinearIdx, perRepShape, inRepOrder);
-    auto repSpatialIdx =
-        mlir::LLVM::delinearize(repLinearIdx, repetitions, repOrder);
-    OperandValueKey key{repSpatialIdx[0], repSpatialIdx[nonKDim],
-                        inRepSpatialIdx[0], inRepSpatialIdx[nonKDim],
-                        inRepSpatialIdx[kDim]};
-    res[key] = elems[idx];
-  }
-  return res;
-}
-
-LogicalResult convertFMADot(triton::DotOp op, triton::DotOp::Adaptor adaptor,
+LogicalResult convertFMADot(DotOp op, DotOp::Adaptor adaptor,
                             const LLVMTypeConverter *typeConverter,
                             ConversionPatternRewriter &rewriter) {
   auto *ctx = rewriter.getContext();
   auto loc = op.getLoc();
-
-  auto A = op.getA();
-  auto D = op.getResult();
-
-  auto aTensorTy = cast<RankedTensorType>(A.getType());
-  auto dTensorTy = cast<RankedTensorType>(D.getType());
-
-  SmallVector<int64_t> aShapePerCTA =
-      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(aTensorTy)));
-  auto dShapePerCTA =
-      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTA(dTensorTy)));
-
-  BlockedEncodingAttr dLayout =
-      cast<BlockedEncodingAttr>(dTensorTy.getEncoding());
-  // TODO process A and B operand separately
-  auto inRepOrder = expandMatrixOrderWithBatch(dLayout.getOrder());
-  auto repOrder = expandMatrixOrderWithBatch(dLayout.getRepOrder());
-  auto cc = unpackLLElements(loc, adaptor.getC(), rewriter);
-
-  Value llA = adaptor.getA();
-  Value llB = adaptor.getB();
-
-  auto sizePerThread =
-      expandMatrixShapeWithBatch(ArrayRef(getSizePerThread(dLayout)));
-  auto numElemsPerThread = product(sizePerThread);
-  auto shapePerCTATile =
-      expandMatrixShapeWithBatch(ArrayRef(getShapePerCTATile(dLayout)));
-
-  unsigned K = aShapePerCTA[2];
-
-  unsigned threadTileShape[3];
-  unsigned repetitions[3];
-  for (int i = 0; i < 3; ++i) {
-    repetitions[i] =
-        ceil(dShapePerCTA[i], static_cast<int64_t>(shapePerCTATile[i]));
-  }
-
-  auto has = getValueTableFromStructFMA(
-      llA, {sizePerThread[0], sizePerThread[1], K},
-      {repetitions[0], repetitions[1], 1},
-      /*kDim*/ 2, /*nonKDim*/ 1, rewriter, loc, inRepOrder, repOrder);
-  auto hbs = getValueTableFromStructFMA(
-      llB, {sizePerThread[0], K, sizePerThread[2]},
-      {repetitions[0], 1, repetitions[2]},
-      /*kDim*/ 1, /*nonKDim*/ 2, rewriter, loc, inRepOrder, repOrder);
-
-  SmallVector<Value> acc = cc;
-
-  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
-    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
-      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
-        for (unsigned b = 0; b < sizePerThread[0]; ++b)
-          for (unsigned m = 0; m < sizePerThread[1]; ++m)
-            for (unsigned n = 0; n < sizePerThread[2]; ++n) {
-              SmallVector<unsigned> multiDimAccumIdx = {b, m, n};
-              unsigned linearInRepIdx =
-                  linearize(multiDimAccumIdx, sizePerThread, inRepOrder);
-              SmallVector<unsigned> multiDimRepIdx = {bRep, mRep, nRep};
-              unsigned linearRepIdx =
-                  linearize(multiDimRepIdx, repetitions, repOrder);
-              unsigned linearAccumIdx =
-                  linearInRepIdx + linearRepIdx * numElemsPerThread;
-              for (unsigned k = 0; k < K; ++k) {
-                auto aOp = has[{bRep, mRep, b, m, k}];
-                auto bOp = hbs[{bRep, nRep, b, n, k}];
-                acc[linearAccumIdx] = rewriter.create<LLVM::FMulAddOp>(
-                    loc, aOp, bOp, acc[linearAccumIdx]);
-              }
-            }
-
-  auto res = packLLElements(loc, typeConverter, acc, rewriter, dTensorTy);
-  rewriter.replaceOp(op, res);
-
-  return success();
+  GenericFMAVectorMultiplier multiplier(rewriter, loc);
+  return parametricConvertFMADot(op, adaptor, typeConverter, rewriter,
+                                 multiplier);
 }
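With the architecture-independent iteration factored into parametricConvertFMADot, a backend-specific entry point reduces to constructing its multiplier and delegating. A minimal sketch, reusing the hypothetical PackedFMAVectorMultiplier from the earlier example:

// Hypothetical backend entry point mirroring convertFMADot above; only the
// multiplier differs, the shared framework does the rest.
LogicalResult convertMyBackendFMADot(DotOp op, DotOp::Adaptor adaptor,
                                     const LLVMTypeConverter *typeConverter,
                                     ConversionPatternRewriter &rewriter) {
  auto loc = op.getLoc();
  PackedFMAVectorMultiplier multiplier(rewriter, loc);
  return parametricConvertFMADot(op, adaptor, typeConverter, rewriter,
                                 multiplier);
}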
