Commit 77422cb
Fix 3d ConvertLayoutToLLVM
1 parent 4775fed

Generalizes the Intel DPAS encoding's CTA attributes (getCTASplitNum, getCTAOrder, getCTAsPerCGA) from hard-coded rank 2 to the encoding's actual rank, makes the batch repetition contribute to the shared-memory offset in DpasMatmulLoader::loadMatrix, and indexes the DPAS accumulator table by {b, m, n} instead of {m, n}. Most of the remaining additions are temporary std::cout traces used to debug the 3D path.

8 files changed: +87 −15 lines changed

lib/Analysis/Allocation.cpp

Lines changed: 25 additions & 4 deletions
@@ -14,6 +14,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -64,6 +65,7 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
                                                RankedTensorType dstTy) {
   Attribute srcLayout = srcTy.getEncoding();
   Attribute dstLayout = dstTy.getEncoding();
+  std::cout << "- in getRepShapeForCvt\n";
 
   if (!cvtNeedsSharedMemory(srcTy, dstTy)) {
     return {};
@@ -80,6 +82,10 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
   auto dstShapePerCTA = getShapePerCTA(dstTy);
   auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
   auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
+  std::cout << "!!!shapePerCTA: " << srcShapePerCTA.size() << " "
+            << dstShapePerCTA.size() << "\n";
+  std::cout << "!!!shapePerCTATile: " << srcShapePerCTATile.size() << " "
+            << dstShapePerCTATile.size() << "\n";
 
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> repShape(rank);
@@ -106,7 +112,9 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
 ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
                                      RankedTensorType dstTy) {
   // Initialize vector sizes and stride
+  std::cout << "getRepShapeForCvt start\n";
   auto repShape = getRepShapeForCvt(srcTy, dstTy);
+  std::cout << "repShape rank: " << repShape.size() << "\n";
   if (repShape.empty())
     return ScratchConfig({}, {});
   ScratchConfig scratchConfig(repShape, repShape);
@@ -118,13 +126,24 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
 
   auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
   scratchConfig.order = outOrd;
+  std::cout << "inOrd: ";
+  for (auto i : inOrd) {
+    std::cout << i << " ";
+  }
+  std::cout << "rank: " << inOrd.size() << "\n";
+  std::cout << "outOrd: ";
+  for (auto i : outOrd) {
+    std::cout << i << " ";
+  }
+  std::cout << "rank: " << outOrd.size() << "\n";
 
   unsigned srcContigPerThread =
       getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
   unsigned dstContigPerThread =
       getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
   // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
   // that we cannot do vectorization.
+  std::cout << "no index issue in getUniqueContigPerThread\n";
   unsigned innerDim = rank - 1;
   scratchConfig.inVec = outOrd[0] != innerDim  ? 1
                         : inOrd[0] != innerDim ? 1
@@ -174,13 +193,9 @@ class AllocationAnalysis {
   using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
 
   void run() {
-    std::cout << "!!!! getValueAndSizes start\n";
     getValuesAndSizes();
-    std::cout << "!!!! resolveLiveness start\n";
     resolveLiveness();
-    std::cout << "!!!! computeOffsets start\n";
     computeOffsets();
-    std::cout << "!!!! AllocationAnalysis end\n";
   }
 
   /// Initializes explicitly defined shared memory values for a given operation.
@@ -237,27 +252,33 @@ class AllocationAnalysis {
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
     } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
+      std::cout << "getScratchValueSize from ConvertLayoutOp\n";
       auto srcTy = cvtLayout.getSrc().getType();
       auto dstTy = cvtLayout.getType();
       auto srcEncoding = srcTy.getEncoding();
       auto dstEncoding = dstTy.getEncoding();
       if (mlir::isa<SharedEncodingAttr>(srcEncoding) ||
           mlir::isa<SharedEncodingAttr>(dstEncoding)) {
         // Conversions from/to shared memory do not need scratch memory.
+        std::cout << "-- ConvertLayoutOp from/to shared memory\n";
         return;
       }
       // ConvertLayoutOp with both input/output non-shared_layout
       // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
       // also possible to realize it with other approaches in restricted
       // conditions, such as warp-shuffle
+      std::cout << "-- getScratchConfigForCvt\n";
      auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
+      std::cout << "-- getNumScratchElements\n";
       auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
       auto bytes =
           isa<triton::PointerType>(srcTy.getElementType())
               ? elems * kPtrBitWidth / 8
               : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
+      std::cout << "-- ConvertLayoutOp from/to non-shared memory: " << bytes
+                << " bytes\n";
     } else if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
       auto value = op->getOperand(0);
       // only scalar requires scratch memory
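
For reference, the scratch-size arithmetic in the ConvertLayoutOp branch above, restated as a standalone scalar computation (a sketch, not part of this commit; kPtrBitWidth is a constant defined in the allocation analysis, 64 on typical targets):

    #include <algorithm>
    #include <cstdio>

    int main() {
      // elems = product of scratchConfig.paddedRepShape, e.g. a 32x32 tile.
      unsigned elems = 32 * 32;
      unsigned elemBits = 16; // e.g. an f16 source tensor
      // Sub-byte element types round up to one byte, mirroring
      // std::max<int>(8, srcTy.getElementTypeBitWidth()) above.
      unsigned bytes = elems * std::max(8u, elemBits) / 8;
      std::printf("%u bytes of scratch\n", bytes); // prints 2048
    }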

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 12 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 
 #include <cstdint>
+#include <iostream>
 #include <numeric>
 
 #include "mlir/IR/DialectImplementation.h"
@@ -383,6 +384,17 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape) {
   unsigned rank = shape.size();
+  std::cout << "!!!GPU dialect - getShapePerCTA\n";
+  std::cout << "CTASplitNum: ";
+  for (auto i : CTASplitNum) {
+    std::cout << i << " ";
+  }
+  std::cout << "\nshape: ";
+  for (auto i : shape) {
+    std::cout << i << " ";
+  }
+  std::cout << "\n";
+
   SmallVector<int64_t> shapePerCTA(rank);
   for (unsigned i = 0; i < rank; ++i) {
     // This wrapping rule must be consistent with emitCTAOffsetForLayout
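
The prints above check that CTASplitNum and shape agree in rank, which is exactly what breaks when a rank-3 tensor meets an encoding that hard-codes rank-2 CTA attributes (the DPAS fix below). A scalar sketch of the per-CTA shape rule, assuming the wrapping division that the surrounding comment refers to:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Hedged model of getShapePerCTA: each dimension is split across CTAs,
    // wrapping when the split factor exceeds the dimension size.
    std::vector<int64_t> shapePerCTA(const std::vector<unsigned> &ctaSplitNum,
                                     const std::vector<int64_t> &shape) {
      std::vector<int64_t> out(shape.size());
      for (size_t i = 0; i < shape.size(); ++i)
        // If ctaSplitNum.size() < shape.size(), this indexes out of bounds:
        // the rank mismatch the debug prints above are hunting for.
        out[i] = shape[i] / std::min<int64_t>(ctaSplitNum[i], shape[i]);
      return out;
    }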

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 13 additions & 6 deletions
@@ -186,17 +186,24 @@ unsigned DpasEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getCTASplitNum() const {
-  SmallVector<unsigned> res{1, 1};
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> res(rank, 1);
   return res;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getCTAOrder() const {
-  SmallVector<unsigned> res{1, 0};
-  return res;
+  size_t rank = getWarpsPerCTA().size();
+  // auto res = llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank)));
+  // return res;
+  if (rank == 3)
+    return {2, 1, 0};
+  else
+    return {1, 0};
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getCTAsPerCGA() const {
-  SmallVector<unsigned> res{1, 1};
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> res(rank, 1);
   return res;
 }
 
@@ -370,8 +377,8 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
   SmallVector<unsigned> elemsPerThread(rank);
   if (rank == 3)
     elemsPerThread[0] = repetitions[0];
-  elemsPerThread[rank - 2] = sizePerThread[rank - 2] * repetitions[1];
-  elemsPerThread[rank - 1] = sizePerThread[rank - 1] * repetitions[2];
+  elemsPerThread[rank - 2] = sizePerThread[0] * repetitions[1];
+  elemsPerThread[rank - 1] = sizePerThread[1] * repetitions[2];
 
   return elemsPerThread;
 };
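
The commented-out lines in getCTAOrder hint at a rank-generic formulation instead of branching on rank == 3. A minimal sketch of that alternative (assuming llvm/ADT/Sequence.h and STLExtras.h are available), which yields {1, 0} for rank 2 and {2, 1, 0} for rank 3, same as the branches above:

    #include "llvm/ADT/STLExtras.h"  // llvm::to_vector, llvm::reverse
    #include "llvm/ADT/Sequence.h"   // llvm::seq
    #include "llvm/ADT/SmallVector.h"

    // Reversed iota {rank-1, ..., 1, 0}: row-major CTA order for any rank.
    llvm::SmallVector<unsigned> getCTAOrderGeneric(size_t rank) {
      return llvm::to_vector<4>(
          llvm::reverse(llvm::seq<unsigned>(0u, static_cast<unsigned>(rank))));
    }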

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 10 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "TargetInfo.h"
 #include "Utility.h"
+#include <iostream>
 
 #include "intel/include/Analysis/Utility.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
@@ -39,6 +40,7 @@ struct ConvertLayoutOpConversion
   LogicalResult
   matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    std::cout << "ConvertLayoutOpConversion" << std::endl;
     RankedTensorType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
@@ -64,6 +66,7 @@ struct ConvertLayoutOpConversion
                     RankedTensorType type,
                     ArrayRef<unsigned> multiDimCTAInRepId,
                     ArrayRef<unsigned> shapePerCTATile) const {
+    std::cout << "getMultiDimOffset" << std::endl;
     auto shape = type.getShape();
     unsigned rank = shape.size();
     if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
@@ -140,6 +143,7 @@ struct ConvertLayoutOpConversion
                       ArrayRef<unsigned> origRepShape,
                       ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
                       Value smemBase) const {
+    std::cout << "processReplica" << std::endl;
     auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
     auto layout = type.getEncoding();
     auto rank = type.getRank();
@@ -225,6 +229,7 @@ struct ConvertLayoutOpConversion
   lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
                                 OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
+    std::cout << "lowerDistributedToDistributed" << std::endl;
     auto loc = op.getLoc();
     auto typeConverter = getTypeConverter();
     RankedTensorType srcTy = op.getSrc().getType();
@@ -324,6 +329,7 @@ struct ConvertLayoutOpConversion
                       ConversionPatternRewriter &rewriter,
                       Value vals,
                       RankedTensorType srcType) const {
+    std::cout << "getValuesFromDpasLayoutStruct" << std::endl;
     SmallVector<Value> elems = unpackLLElements(loc, vals, rewriter);
     auto dpasLayout = dyn_cast<DpasEncodingAttr>(srcType.getEncoding());
 
@@ -368,6 +374,7 @@ struct ConvertLayoutOpConversion
   Value composeValuesToDotOperandLayoutStruct(
       Location loc, ConversionPatternRewriter &rewriter, const ValueTable &vals,
       RankedTensorType dstType) const {
+    std::cout << "composeValuesToDotOperandLayoutStruct" << std::endl;
     auto dotLayout = dyn_cast<DotOperandEncodingAttr>(dstType.getEncoding());
     auto dpasLayout = dyn_cast<DpasEncodingAttr>(dotLayout.getParent());
     unsigned opIdx = dotLayout.getOpIdx();
@@ -424,6 +431,7 @@ struct ConvertLayoutOpConversion
   LogicalResult
   lowerDpasToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter) const {
+    std::cout << "lowerDpasToDotOperand" << std::endl;
     Location loc = op.getLoc();
     RankedTensorType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
@@ -456,6 +464,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   LogicalResult
   matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    std::cout << "ConvertLayoutOpUsingLinearLayoutsConversion" << std::endl;
     MLIRContext *ctx = op.getContext();
 
     const auto &shape = op.getType().getShape();
@@ -504,6 +513,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout,
                        const LinearLayout &dstLayout, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
+    std::cout << "transferWithinThread" << std::endl;
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
     StringAttr kRegister = str_attr("register");
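
Every hunk in this file is an entry trace on one of the lowering paths (the legacy ConvertLayoutOpConversion vs. the linear-layout pattern). If such traces were meant to outlive the debugging session, an RAII guard that also reports exits would be tidier than one-sided couts; a sketch, not part of this commit:

    #include <iostream>
    #include <string>
    #include <utility>

    // Hypothetical helper: prints matching enter/leave lines around a
    // lowering routine, covering early returns as well.
    struct TraceScope {
      std::string name;
      explicit TraceScope(std::string n) : name(std::move(n)) {
        std::cout << "enter " << name << std::endl;
      }
      ~TraceScope() { std::cout << "leave " << name << std::endl; }
    };

    // Usage at the top of e.g. lowerDistributedToDistributed:
    //   TraceScope ts("lowerDistributedToDistributed");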

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp

Lines changed: 19 additions & 1 deletion
@@ -186,6 +186,9 @@ template <unsigned opIdx>
 Value DpasMatmulLoader<opIdx>::loadMatrix(
     int repBatch, int repOuter, int repInner, const ArrayRef<Value> ptrs,
     LLVM::LLVMStructType structTy, Type smemTy, Value cSwizzleOffset) const {
+  std::cout << "-- loadMatrix: repBatch: " << repBatch
+            << ", repOuter: " << repOuter << ", repInner: " << repInner
+            << std::endl;
   Type elemTy = structTy.getBody()[0];
   assert(
       llvm::any_of(structTy.getBody(), [&](Type ty) { return ty == elemTy; }) &&
@@ -195,7 +198,11 @@ Value DpasMatmulLoader<opIdx>::loadMatrix(
   Value offsetOuter = mul(i32_val(repOuter), repNonKDimStride);
   Value offsetInner = mul(i32_val(repInner), repKDimStride);
   Value offset = add(offsetOuter, offsetInner);
-  // offset = add(offset, offsetBatch);
+  // FIXME: repBatchSize and
+  if (repBatch > 0) {
+    Value offsetBatch = mul(i32_val(repBatch), repBatchDimStride);
+    offset = add(offset, offsetBatch);
+  }
 
   Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
   size_t elemNum = structTy.getBody().size();
@@ -206,6 +213,7 @@ Value DpasMatmulLoader<opIdx>::loadMatrix(
     llvmStruct = insert_val(structTy, llvmStruct, val, i);
   }
 
+  std::cout << "-- loadMatrix end --" << std::endl;
   return llvmStruct;
 }
 
@@ -234,6 +242,7 @@ Value composeValuesToDotOperandLayoutStruct(
   Type structTy = LLVM::LLVMStructType::getLiteral(
       ctx, SmallVector<Type>(elems.size(), elemTy));
 
+  std::cout << "packLLElements: elems size: " << elems.size() << std::endl;
   return packLLElements(loc, typeConverter, elems, rewriter, structTy);
 }
 
@@ -269,6 +278,12 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
   auto sharedLayout = cast<SharedEncodingAttr>(descTy.getEncoding());
   ArrayRef<unsigned> order = sharedLayout.getOrder();
 
+  std::cout << "getLoadMatrixFn: sharedLayout order: ";
+  for (auto i : order) {
+    std::cout << i << " ";
+  }
+  std::cout << std::endl;
+
   // (a, b) is the coordinate.
   auto load = [=, &rewriter, &smemObj, &instrShape, &vals](int batch, int outer,
                                                            int inner) {
@@ -353,6 +368,9 @@ Value loadOperand(ConversionPatternRewriter &rewriter, Location loc,
   int64_t numRepOuter = numReps[opIdx ? 2 : 1];
   int64_t numRepK = numReps[opIdx ? 1 : 2];
 
+  std::cout << "!!! numRepBatch: " << numRepBatch
+            << ", numRepOuter: " << numRepOuter << ", numRepK: " << numRepK
+            << "\n";
   for (int b = 0; b < numRepBatch; ++b)
     for (int m = 0; m < numRepOuter; ++m)
       for (int k = 0; k < numRepK; ++k)
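
The substantive change in this file is that repBatch now contributes to the shared-memory offset in loadMatrix (the term was previously commented out). A scalar model of the addressing, with plain integers standing in for the i32 Values built with mul/add (names hypothetical):

    #include <cstdint>
    #include <iostream>

    // offset = outer*nonKStride + inner*kStride (+ batch*batchStride for 3D).
    int64_t tileOffset(int repBatch, int repOuter, int repInner,
                       int64_t batchStride, int64_t nonKStride,
                       int64_t kStride) {
      int64_t offset = repOuter * nonKStride + repInner * kStride;
      // The guard mirrors the diff: for repBatch == 0 the term is zero anyway,
      // but in the real code it avoids emitting dead IR.
      if (repBatch > 0)
        offset += repBatch * batchStride;
      return offset;
    }

    int main() {
      // The second batch repetition of the (0, 0) tile starts one full
      // batch stride further into shared memory.
      std::cout << tileOffset(1, 0, 0, /*batchStride=*/4096,
                              /*nonKStride=*/512, /*kStride=*/32)
                << "\n"; // prints 4096
    }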

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,7 @@
 #include "../TritonGPUToLLVMBase.h"
 #include "../Utility.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include <iostream>
 
 #include "intel/include/Analysis/DPAS.h"
 #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
@@ -176,6 +177,9 @@ class DotOpDPASConversionHelper {
   });
 
   auto generateDPASOp = [&](unsigned b, unsigned m, unsigned n, unsigned k) {
+    std::cout << "valA: " << b << " " << m << " " << k << "\n";
+    std::cout << "valB: " << b << " " << n << " " << k << "\n";
+    std::cout << "valC: " << b << " " << m << " " << n << "\n";
     Value valA = ha.at({b, m, k});
     Value valB = hb.at({b, n, k});
     Value valc = fc.at({b, m, n});
@@ -186,7 +190,7 @@ class DotOpDPASConversionHelper {
         TritonGEN::PrecisionTypeAttr::get(B.getContext(), BPrecision);
     auto RC = IntegerAttr::get(rewriter.getIntegerType(32),
                                dpasEncoding.getRepeatCount());
-    fc.at({m, n}) = rewriter.create<TritonGEN::MatrixDPASOp>(
+    fc.at({b, m, n}) = rewriter.create<TritonGEN::MatrixDPASOp>(
         loc, dTy, valc, valA, valB, pA, pB, RC);
   };
third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 1 addition & 3 deletions
@@ -186,9 +186,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
   auto sharedLayout =
       cast<SharedEncodingAttr>(op.getSrc().getType().getEncoding());
 
-  sharedLayout.dump();
-  std::cout << "!!! sharedLayout order: "
-            << "\n";
+  std::cout << "!!! sharedLayout order: ";
   for (auto o : sharedLayout.getOrder()) {
     std::cout << o << " ";
   }

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <iostream>
 
 #define DEBUG_TYPE "ttgpu_to_llvm"
 
@@ -573,6 +574,7 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) {
 inline SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset) {
+  std::cout << "emitIndices" << std::endl;
   MLIRContext *ctx = rewriter.getContext();
   auto shape = type.getShape();
   std::optional<LinearLayout> ll = triton::gpu::toLinearLayout(shape, layout);
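
This header already defines DEBUG_TYPE, so a less invasive alternative to the raw std::cout tracing used throughout this commit is LLVM's debug macro, which compiles out of release builds and can be filtered at runtime with -debug-only=ttgpu_to_llvm; a sketch:

    #include "llvm/Support/Debug.h"

    #define DEBUG_TYPE "ttgpu_to_llvm"

    // Inside emitIndices, the equivalent of the std::cout line above, but
    // only active in assertion-enabled builds under -debug-only=ttgpu_to_llvm:
    //   LLVM_DEBUG(llvm::dbgs() << "emitIndices\n");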
