Commit af334a2

conversion tritonGPUToLLVM utility decoupling (has a bug)

1 parent: 7e0dae7
7 files changed, +259 -28 lines

third_party/iluvatar/backend/flagtree_backend_specialization/include/flagtree_spec.h

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 #include "triton/Analysis/iluvatar_Utility.h"
 #include "triton/Conversion/TritonGPUToLLVM/iluvatar_ElementwiseOpToLLVMBase.h"
 #include "triton/Conversion/TritonGPUToLLVM/iluvatar_TargetInfoBase.h"
+#include "triton/Conversion/TritonGPUToLLVM/iluvatar_Utility.h"
Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+#ifndef ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
+#define ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
+
+#define FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_heads
+#define FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_functionPtr
+#define FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_createIndexConstant
+#define FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_getMultiDimOffset_ARG bool
+#define FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_IluvatarMmaEncodingAttr
+#define FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_emitBaseIndexForLayoutImpl
+#define FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_emitOffsetForLayout
+#define FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_getSwizzledSharedPtrs
+#define FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_storeDistributedToShared
+
+#endif // ILUVATAR_TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
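These macros come in two flavors, both consumed by the Utility.h diff further down: most are plain flags whose mere presence flips an #ifdef/#ifndef, while FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_getMultiDimOffset_ARG additionally carries a value (the token `bool`) that is spliced into a signature. A minimal sketch of the two kinds, not part of the commit (`f` and `spec_arg` are illustrative names):

// Flag macro: presence alone compiles the Iluvatar-specific branch in.
#ifdef FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_emitOffsetForLayout
// ... Iluvatar-only code ...
#endif

// Value macro: expands to `bool`, so a guarded declaration can grow an
// extra trailing parameter.
void f(FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_getMultiDimOffset_ARG
           spec_arg = false);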
Lines changed: 2 additions & 1 deletion
@@ -1 +1,2 @@
-add_subdirectory(Analysis)
+add_subdirectory(Analysis)
+add_subdirectory(Conversion/TritonGPUToLLVM)
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+add_triton_library(FlagTree_iluvatar_TritonConversionTritonGPUToLLVM
+  Utility.cpp
+
+  DEPENDS
+  TritonTableGen
+  TritonGPUAttrDefsIncGen
+)
Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/../../lib/Conversion/TritonGPUToLLVM/Utility.cpp"
+
+namespace mlir {
+namespace LLVM {
+
+Value createIndexConstant(OpBuilder &builder, Location loc,
+                          TypeConverter *converter, int64_t value) {
+  Type ty = converter->convertType(builder.getIndexType());
+  return builder.create<LLVM::ConstantOp>(loc, ty,
+                                          builder.getIntegerAttr(ty, value));
+}
+
+SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
+                                     ConversionPatternRewriter &rewriter,
+                                     const TargetInfoBase &targetInfo,
+                                     unsigned elemId, RankedTensorType type,
+                                     ArrayRef<unsigned> multiDimCTAInRepId,
+                                     ArrayRef<unsigned> shapePerCTATile,
+                                     bool isTrans, bool stNotRd) {
+  auto shape = type.getShape();
+  unsigned rank = shape.size();
+  if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
+    auto multiDimOffsetFirstElem = emitBaseIndexForLayout(
+        loc, rewriter, targetInfo, blockedLayout, type, false);
+    SmallVector<Value> multiDimOffset(rank);
+    SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
+        elemId, getSizePerThread(layout), getOrder(layout));
+    for (unsigned d = 0; d < rank; ++d) {
+      multiDimOffset[d] =
+          add(multiDimOffsetFirstElem[d],
+              i32_val(multiDimCTAInRepId[d] * shapePerCTATile[d] +
+                      multiDimElemId[d]));
+    }
+    return multiDimOffset;
+  }
+  if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
+    unsigned dim = sliceLayout.getDim();
+    auto parentEncoding = sliceLayout.getParent();
+    auto parentSizePerThread = getSizePerThread(parentEncoding);
+    auto parentShape = sliceLayout.paddedShape(shape);
+    auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
+                                          parentEncoding);
+    auto offsets = emitOffsetForLayout(layout, type);
+    auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy);
+    SmallVector<int> idxs;
+    for (SmallVector<unsigned> off : offsets) {
+      off.insert(off.begin() + dim, 0);
+      auto it = std::find(parentOffset.begin(), parentOffset.end(), off);
+      idxs.push_back(std::distance(parentOffset.begin(), it));
+    }
+    auto multiDimOffsetParent = getMultiDimOffset(
+        parentEncoding, loc, rewriter, targetInfo, idxs[elemId], parentTy,
+        sliceLayout.paddedShape(multiDimCTAInRepId),
+        sliceLayout.paddedShape(shapePerCTATile));
+    SmallVector<Value> multiDimOffset(rank);
+    for (unsigned d = 0; d < rank + 1; ++d) {
+      if (d == dim)
+        continue;
+      unsigned slicedD = d < dim ? d : (d - 1);
+      multiDimOffset[slicedD] = multiDimOffsetParent[d];
+    }
+    return multiDimOffset;
+  }
+  if (auto mmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
+    assert((rank == 2 || (rank == 3 && mmaLayout.isAmpere())) &&
+           "Unexpected rank");
+    auto shapePerCTA = getShapePerCTA(mmaLayout, shape);
+    auto instrShape = mmaLayout.getInstrShape();
+    SmallVector<Value> mmaColIdx(2);
+    SmallVector<Value> mmaRowIdx(2);
+    Value threadId = getThreadId(rewriter, loc);
+    Value warpSize = i32_val(32);
+    Value laneId = urem(threadId, warpSize);
+    Value warpId = udiv(threadId, warpSize);
+    // TODO: fix the bug in MMAEncodingAttr document
+    SmallVector<Value> multiDimWarpId(2);
+    auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
+    auto warpOrder = triton::gpu::getWarpOrder(mmaLayout);
+    multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder);
+    Value _1 = i32_val(1);
+    Value _2 = i32_val(2);
+    Value _4 = i32_val(4);
+    Value _8 = i32_val(8);
+    Value _16 = i32_val(16);
+    if (mmaLayout.isAmpere() || mmaLayout.isHopper()) {
+      multiDimWarpId[rank - 1] = urem(
+          multiDimWarpId[rank - 1],
+          i32_val(ceil<unsigned>(shapePerCTA[rank - 1], instrShape[rank - 1])));
+      multiDimWarpId[rank - 2] = urem(
+          multiDimWarpId[rank - 2],
+          i32_val(ceil<unsigned>(shapePerCTA[rank - 2], instrShape[rank - 2])));
+
+      Value mmaGrpId = udiv(laneId, _4);
+      Value mmaGrpIdP8 = add(mmaGrpId, _8);
+      Value mmaThreadIdInGrp = urem(laneId, _4);
+      Value mmaThreadIdInGrpM2 = mul(mmaThreadIdInGrp, _2);
+      Value mmaThreadIdInGrpM2P1 = add(mmaThreadIdInGrpM2, _1);
+      Value rowWarpOffset =
+          mul(multiDimWarpId[rank - 2], i32_val(instrShape[rank - 2]));
+      mmaRowIdx[0] = add(mmaGrpId, rowWarpOffset);
+      mmaRowIdx[1] = add(mmaGrpIdP8, rowWarpOffset);
+      Value colWarpOffset =
+          mul(multiDimWarpId[rank - 1], i32_val(instrShape[rank - 1]));
+      mmaColIdx[0] = add(mmaThreadIdInGrpM2, colWarpOffset);
+      mmaColIdx[1] = add(mmaThreadIdInGrpM2P1, colWarpOffset);
+    } else if (mmaLayout.isVolta()) {
+      // Volta doesn't follow the pattern here.
+    } else {
+      llvm_unreachable("Unexpected MMALayout version");
+    }
+
+    SmallVector<Value> multiDimOffset(rank);
+    if (mmaLayout.isHopper()) {
+      unsigned elemIdRem4 = elemId % 4;
+      unsigned nGrpId = elemId / 4;
+      multiDimOffset[0] = elemIdRem4 < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
+      multiDimOffset[1] = elemIdRem4 % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
+      multiDimOffset[1] = add(multiDimOffset[1], i32_val(8 * nGrpId));
+      multiDimOffset[0] = add(multiDimOffset[0], i32_val(multiDimCTAInRepId[0] *
+                                                         shapePerCTATile[0]));
+      multiDimOffset[1] = add(multiDimOffset[1], i32_val(multiDimCTAInRepId[1] *
+                                                         shapePerCTATile[1]));
+    } else if (mmaLayout.isAmpere()) {
+      if (rank == 3)
+        multiDimOffset[0] =
+            add(multiDimWarpId[0],
+                i32_val(multiDimCTAInRepId[0] * shapePerCTATile[0]));
+      multiDimOffset[rank - 2] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
+      multiDimOffset[rank - 1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
+      multiDimOffset[rank - 2] =
+          add(multiDimOffset[rank - 2], i32_val(multiDimCTAInRepId[rank - 2] *
+                                                shapePerCTATile[rank - 2]));
+      multiDimOffset[rank - 1] =
+          add(multiDimOffset[rank - 1], i32_val(multiDimCTAInRepId[rank - 1] *
+                                                shapePerCTATile[rank - 1]));
+    } else if (mmaLayout.isVolta()) {
+      auto [isARow, isBRow, isAVec4, isBVec4, _] =
+          mmaLayout.decodeVoltaLayoutStates();
+      auto coords = SharedToDotOperandMMAv1::getMNCoords(
+          threadId, loc, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape,
+          isARow, isBRow, isAVec4, isBVec4);
+      return coords[elemId];
+    } else {
+      llvm_unreachable("Unexpected MMALayout version");
+    }
+    return multiDimOffset;
+  }
+  if (auto mmaLayout = mlir::dyn_cast<IluvatarMmaEncodingAttr>(layout)) {
+    assert(rank == 2 && "Unexpected rank");
+    SmallVector<Value> multiDimOffset(rank);
+    Value threadId = getThreadId(rewriter, loc);
+    if (mmaLayout.isVolta()) {
+      int bitwidth = type.getElementType().getIntOrFloatBitWidth();
+      int elemVecSize = stNotRd ? (32 / bitwidth) : 1;
+      static auto func = SharedToDotOperandMMAv1::load_getMNCoords_func(
+          "iluvatar", "getMNCoords");
+      auto coords = func(threadId, loc, rewriter, mmaLayout.getWarpsPerCTA(),
+                         mmaLayout, shape, bitwidth, elemVecSize, isTrans);
+      return coords[elemId];
+    } else {
+      llvm_unreachable("Unexpected MMALayout version");
+    }
+  }
+  if (isa<AMDMfmaEncodingAttr, AMDWmmaEncodingAttr>(layout)) {
+    auto multiDimBase =
+        emitBaseIndexForLayout(loc, rewriter, targetInfo, layout, type, false);
+    SmallVector<SmallVector<unsigned>> offsets;
+    assert(rank == 2);
+    SmallVector<Value> multiDimOffset(rank);
+    if (auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout)) {
+      emitMfmaOffsetForCTA(mfmaLayout, offsets, 0, multiDimCTAInRepId[0],
+                           multiDimCTAInRepId[1]);
+    } else if (auto wmmaLayout = dyn_cast<AMDWmmaEncodingAttr>(layout)) {
+      emitWmmaOffsetForCTA(wmmaLayout, offsets, 0, multiDimCTAInRepId[0],
+                           multiDimCTAInRepId[1]);
+    }
+    multiDimOffset[0] = add(multiDimBase[0], i32_val(offsets[elemId][0]));
+    multiDimOffset[1] = add(multiDimBase[1], i32_val(offsets[elemId][1]));
+    return multiDimOffset;
+  }
+  llvm_unreachable("unexpected layout in getMultiDimOffset");
+}
+
+} // namespace LLVM
+} // namespace mlir
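Note the second #include above: this translation unit textually pulls in the upstream lib/Conversion/TritonGPUToLLVM/Utility.cpp and then defines createIndexConstant and getMultiDimOffset in the same mlir::LLVM namespace. If the upstream file's copies of those symbols are also compiled into a library in the same link, that is a duplicate-symbol/ODR hazard, which may well be the bug the commit title concedes; this is an inference from the diff, not something the commit states.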

third_party/iluvatar/include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 43 additions & 8 deletions
@@ -6,27 +6,36 @@
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "python/src/plugin.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+
+#ifndef FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_heads
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#else
+#include "python/src/plugin.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#endif
+
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/ErrorHandling.h"
 
+#include "triton/../../backend/flagtree_backend_specialization/include/flagtree_spec.h"
+
 #define DEBUG_TYPE "ttgpu_to_llvm"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 using namespace mlir;
 using namespace mlir::triton;
 
+#ifdef FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_functionPtr
 using emitOffsetForTCULayoutFunc = SmallVector<SmallVector<unsigned>> (*)(
     const triton::gpu::IluvatarMmaEncodingAttr &, RankedTensorType);
 DEFINE_LOAD_FUNC(emitOffsetForTCULayout)
@@ -39,6 +48,7 @@ DEFINE_LOAD_FUNC(emitBaseIndexForTCULayout)
 using remapOffsetFunc = Value (*)(Value, Value, RankedTensorType, bool,
                                   Location, RewriterBase &, int, bool);
 DEFINE_LOAD_FUNC(remapOffset)
+#endif
 
 // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
 // Operators
@@ -245,8 +255,13 @@ Value createConstantF64(Location loc, OpBuilder &rewriter, double v);
 Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type);
 
 /// Create an index type constant.
+#ifndef FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_createIndexConstant
+Value createIndexConstant(OpBuilder &builder, Location loc,
+                          const TypeConverter *converter, int64_t value);
+#else
 Value createIndexConstant(OpBuilder &builder, Location loc,
                           TypeConverter *converter, int64_t value);
+#endif
 
 /// Create an integer constant of \param width bits.
 Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
@@ -359,11 +374,23 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
 // the smem buffer. Recall that the smem buffer will only store a single replica
 // when converting distributed to distributed layout. Also, a replica is the
 // smallest CTA tile that is common between input and output layouts.
-SmallVector<Value> getMultiDimOffset(
-    Attribute layout, Location loc, ConversionPatternRewriter &rewriter,
-    const TargetInfoBase &targetInfo, unsigned elemId, RankedTensorType type,
-    ArrayRef<unsigned> multiDimCTAInRepId, ArrayRef<unsigned> shapePerCTATile,
-    bool isTrans = false, bool stNotRd = false);
+#ifndef FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_getMultiDimOffset_ARG
+SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
+                                     ConversionPatternRewriter &rewriter,
+                                     const TargetInfoBase &targetInfo,
+                                     unsigned elemId, RankedTensorType type,
+                                     ArrayRef<unsigned> multiDimCTAInRepId,
+                                     ArrayRef<unsigned> shapePerCTATile);
+#else
+SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
+                                     ConversionPatternRewriter &rewriter,
+                                     const TargetInfoBase &targetInfo,
+                                     unsigned elemId, RankedTensorType type,
+                                     ArrayRef<unsigned> multiDimCTAInRepId,
+                                     ArrayRef<unsigned> shapePerCTATile,
+                                     FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_getMultiDimOffset_ARG spec_arg1 = false,
+                                     FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_getMultiDimOffset_ARG spec_arg2 = false);
+#endif
 
 // Given a multiDimOffset, this function wraps around each dimension to be
 // within shape.
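Since iluvatar_Utility.h defines FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_getMultiDimOffset_ARG as `bool`, the #else branch above preprocesses to a declaration equivalent to the following sketch (spec_arg1/spec_arg2 correspond to isTrans/stNotRd in the new Utility.cpp definition):

SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const TargetInfoBase &targetInfo,
                                     unsigned elemId, RankedTensorType type,
                                     ArrayRef<unsigned> multiDimCTAInRepId,
                                     ArrayRef<unsigned> shapePerCTATile,
                                     bool spec_arg1 = false,   // isTrans
                                     bool spec_arg2 = false);  // stNotRd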
@@ -434,7 +461,11 @@ using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::CTALayoutAttr;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
+
+#ifdef FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_IluvatarMmaEncodingAttr
 using ::mlir::triton::gpu::IluvatarMmaEncodingAttr;
+#endif
+
 using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -1128,11 +1159,13 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
     if (mmaLayout.isAmpere() || mmaLayout.isHopper())
       result = emitBaseIndexWithinCTAForMmaLayoutV2V3(loc, rewriter, mmaLayout,
                                                       type);
+#ifdef FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_emitBaseIndexForLayoutImpl
   } else if (auto mmaLayout = mlir::dyn_cast<IluvatarMmaEncodingAttr>(layout)) {
     if (mmaLayout.isVolta()) {
       DEFINE_CALL_LOAD_FUNC(iluvatar, emitBaseIndexForTCULayout)
      result = func(loc, rewriter, mmaLayout, type);
    }
+#endif
  } else if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(layout)) {
    result = emitBaseIndexForMfmaLayout(loc, rewriter, mfmaLayout, type);
  } else if (auto wmmaLayout = mlir::dyn_cast<AMDWmmaEncodingAttr>(layout)) {
@@ -1201,12 +1234,14 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) {
     if (mmaLayout.isHopper())
       return emitOffsetForMmaLayoutV3(mmaLayout, type);
   }
+#ifdef FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_emitOffsetForLayout
   if (auto mmaLayout = dyn_cast<IluvatarMmaEncodingAttr>(layout)) {
     if (mmaLayout.isVolta()) {
       DEFINE_CALL_LOAD_FUNC(iluvatar, emitOffsetForTCULayout)
       return func(mmaLayout, type);
     }
   }
+#endif
   if (auto mfmaLayout = mlir::dyn_cast<AMDMfmaEncodingAttr>(layout)) {
     return emitOffsetForMfmaLayout(mfmaLayout, type);
   }
@@ -1362,7 +1397,7 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
       }
       // compute phase = (row // perPhase) % maxPhase
       Value phase = urem(udiv(idxRow, i32_val(perPhase)), i32_val(maxPhase));
-#if defined(__ILUVATAR__)
+#ifdef FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_getSwizzledSharedPtrs
       // corex swizzle
       bool isRow = outOrder[0] == 1;
       Value off = NULL;
@@ -1524,7 +1559,7 @@ inline void storeDistributedToShared(Value src, ArrayRef<Value> inVals,
   // If the shmem layout is not swizzled, we can trivially vectorize stores
   // across the whole width of the most-minor dimension of the shape, because
   // Triton requires all the dims are powers of 2.
-#ifdef __ILUVATAR__
+#ifdef FLAGTREE_SPEC_Conversion_TritonGPUToLLVM_Utility_storeDistributedToShared
   unsigned outVec = dstSharedLayout.getVec();
 #else
   unsigned outVec = dstSharedLayout.getMaxPhase() == 1
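The last two hunks replace the compiler-wide __ILUVATAR__ guard with per-feature FLAGTREE_SPEC_* macros, so these specializations are now switched on by the backend's iluvatar_Utility.h header rather than by a global build flag.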
