Commit 4775fed

Fix AllocationShareMemory

1 parent 2dd1d03

7 files changed: +74, -43 lines
lib/Analysis/Allocation.cpp
Lines changed: 5 additions & 0 deletions

@@ -1,6 +1,7 @@
 #include "triton/Analysis/Allocation.h"
 
 #include <algorithm>
+#include <iostream>
 #include <limits>
 #include <numeric>
 
@@ -173,9 +174,13 @@ class AllocationAnalysis {
   using GraphT = DenseMap<BufferT *, DenseSet<BufferT *>>;
 
   void run() {
+    std::cout << "!!!! getValueAndSizes start\n";
     getValuesAndSizes();
+    std::cout << "!!!! resolveLiveness start\n";
     resolveLiveness();
+    std::cout << "!!!! computeOffsets start\n";
     computeOffsets();
+    std::cout << "!!!! AllocationAnalysis end\n";
   }
 
   /// Initializes explicitly defined shared memory values for a given operation.
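Note that these trace points print unconditionally in every build. A minimal alternative sketch using LLVM's debug facility instead of std::cout (LLVM_DEBUG and -debug-only are real LLVM infrastructure; the DEBUG_TYPE tag here is invented for illustration, and this is not what the commit does):

    #include "llvm/Support/Debug.h"

    #define DEBUG_TYPE "allocation-analysis" // hypothetical tag, not in the commit

    void run() {
      LLVM_DEBUG(llvm::dbgs() << "getValuesAndSizes start\n");
      getValuesAndSizes();
      LLVM_DEBUG(llvm::dbgs() << "resolveLiveness start\n");
      resolveLiveness();
      LLVM_DEBUG(llvm::dbgs() << "computeOffsets start\n");
      computeOffsets();
      LLVM_DEBUG(llvm::dbgs() << "AllocationAnalysis end\n");
    }

Built with assertions enabled, the output then appears only when the tool runs with -debug-only=allocation-analysis, so tracing like this can stay in the tree without spamming release runs.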

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
Lines changed: 42 additions & 37 deletions

@@ -56,8 +56,8 @@ LinearLayout identityND(StringAttr inDimName, ArrayRef<unsigned> shape,
   LinearLayout ret = LinearLayout::empty();
   for (int i = 0; i < shape.size(); i++) {
     // Start with the most-minor dimension, which is order[0].
-    std::cout << "i: " << i << " shape[i]: " << shape[i]
-              << " order[i]: " << order[i] << std::endl;
+    // std::cout << "i: " << i << " shape[i]: " << shape[i]
+    //           << " order[i]: " << order[i] << std::endl;
     int dim = order[i];
     ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]);
   }
@@ -291,15 +291,16 @@ LinearLayout ensureLayoutNotSmallerThan(
     assert(actualSize > desiredSize ||
            desiredSize % actualSize == 0 && "bad shape");
     ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName);
-    std::cout << "actualSize: " << actualSize << " desiredSize: " << desiredSize
-              << std::endl;
-    std::cout << "outDimName: " << outDimName.str() << std::endl;
-    std::cout << "identity1D: "
-              << LinearLayout::identity1D(desiredSize / actualSize, kDim,
-                                          outDimName)
-                     .toString()
-              << std::endl;
-    std::cout << "ret: " << ret.toString() << std::endl;
+    // std::cout << "actualSize: " << actualSize << " desiredSize: " <<
+    // desiredSize
+    //           << std::endl;
+    // std::cout << "outDimName: " << outDimName.str() << std::endl;
+    // std::cout << "identity1D: "
+    //           << LinearLayout::identity1D(desiredSize / actualSize, kDim,
+    //                                       outDimName)
+    //                  .toString()
+    //           << std::endl;
+    // std::cout << "ret: " << ret.toString() << std::endl;
     assert(ret.getOutDimSize(outDimName) >= desiredSize && "bad grow");
   }
   return ret;
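The grow step in this hunk works by multiplying in an identity layout sized to the missing factor. A worked example with invented numbers (mirroring the line above, not part of the commit): if ret currently maps kDim onto outDimName with an out-dim size of 4 and the desired size is 16, then

    // Hypothetical sizes: the product of two layouts multiplies their
    // sizes along a shared out-dimension, so this grows outDimName
    // from 4 to 4 * (16 / 4) = 16 and the "bad grow" assert holds.
    ret *= LinearLayout::identity1D(16 / 4, kDim, outDimName);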
@@ -327,8 +328,8 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
   for (auto s : shape) {
     std::cout << s << ", ";
   }
-
   std::cout << std::endl;
+
   llvm::SmallDenseMap<StringAttr, int64_t> labeledShape;
   for (auto [dim, size] : llvm::zip(outDimNames, shape)) {
     labeledShape[dim] = size;

@@ -337,7 +338,7 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
   LinearLayout cgaLayout =
       ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape)
           .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
-  std::cout << "\ncgaLayout: " << cgaLayout.toString() << std::endl;
+  // std::cout << "\ncgaLayout: " << cgaLayout.toString() << std::endl;
 
   // Calculate the shape of the ctaLayout, which is `shape` divided by the
   // cgaLayout's size.

@@ -346,29 +347,32 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
              llvm::to_vector(cgaLayout.getOutDimNames()) &&
          "bad layout");
 
-  std::cout << "ctaShape: ";
+  // std::cout << "ctaShape: ";
   for (auto dim : ctaLayout.getOutDimNames()) {
     ctaShape[dim] =
         std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim));
-    std::cout << ctaShape[dim] << ", ";
+    // std::cout << ctaShape[dim] << ", ";
   }
-  std::cout << std::endl;
+  // std::cout << std::endl;
 
+  std::cout << "ensureLayoutNotSmallerThan start" << std::endl;
   ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape);
-  std::cout << "\nctaLayout not smaller than: " << ctaLayout.toString()
-            << std::endl;
+  // std::cout << "\nctaLayout not smaller than: " << ctaLayout.toString()
+  //           << std::endl;
+  std::cout << "ensureLayoutNotLargerThan start" << std::endl;
   ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape);
-  std::cout << "\nctaLayout not larger than: " << ctaLayout.toString()
-            << std::endl;
+  // std::cout << "\nctaLayout not larger than: " << ctaLayout.toString()
+  //           << std::endl;
 
-  std::cout << "\ncta * cga: " << (ctaLayout * cgaLayout).toString()
-            << std::endl;
+  // std::cout << "\ncta * cga: " << (ctaLayout * cgaLayout).toString()
+  //           << std::endl;
   LinearLayout ret =
       (std::move(ctaLayout) * std::move(cgaLayout)).transposeOuts(outDimNames);
   for (auto dim : ret.getOutDimNames()) {
     assert(ret.getOutDimSize(dim) == labeledShape[dim] && "bad shape");
   }
-  std::cout << "\ncombineCtaCgaWithShape: " << ret.toString() << std::endl;
+  // std::cout << "\ncombineCtaCgaWithShape: " << ret.toString() << std::endl;
+  std::cout << "combineCtaCgaWithShape end" << std::endl;
   return ret;
 }

@@ -593,26 +597,26 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
         DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
     tileLayout = LinearLayout({{kRegister, regBasesC}, {kLane, laneBasesC}},
                               ArrayRef(outDimNames).take_back(2));
-    std::cout << tileLayout.toString() << std::endl;
+    // std::cout << tileLayout.toString() << std::endl;
     // The per-inst layout is repeated at each repCluster.
     // Hence, multiply with the identity layouts starting from the
     // least significant dimension.
     dimNonK = rank - 2;
     dimK = rank - 1;
     tileLayout *= LinearLayout::identity1D(repCluster[dimK], kRegister,
                                            outDimNames[dimK]);
-    std::cout << (LinearLayout::identity1D(repCluster[dimK], kRegister,
-                                           outDimNames[dimK])
-                      .toString())
-              << std::endl;
-    std::cout << (tileLayout.toString()) << std::endl;
+    // std::cout << (LinearLayout::identity1D(repCluster[dimK], kRegister,
+    //                                        outDimNames[dimK])
+    //                   .toString())
+    //           << std::endl;
+    // std::cout << (tileLayout.toString()) << std::endl;
     tileLayout *= LinearLayout::identity1D(repCluster[dimNonK], kRegister,
                                            outDimNames[dimNonK]);
-    std::cout << (LinearLayout::identity1D(repCluster[dimNonK], kRegister,
-                                           outDimNames[dimNonK])
-                      .toString())
-              << std::endl;
-    std::cout << (tileLayout.toString()) << std::endl;
+    // std::cout << (LinearLayout::identity1D(repCluster[dimNonK], kRegister,
+    //                                        outDimNames[dimNonK])
+    //                   .toString())
+    //           << std::endl;
+    // std::cout << (tileLayout.toString()) << std::endl;
 
     // // The identical layout is repeated among warps
     tileLayout *=

@@ -622,7 +626,7 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
     if (rank == 3)
       tileLayout *=
          LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
-    std::cout << (tileLayout.toString()) << std::endl;
+    // std::cout << (tileLayout.toString()) << std::endl;
   }
 
   // Lastly, the layout repeats to match the shape.

@@ -647,8 +651,9 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
   if (rank == 3)
     tileLayout *=
         LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);
-  std::cout << "\ntileLayout with DPASRepetition: " << (tileLayout.toString())
-            << std::endl;
+  // std::cout << "\ntileLayout with DPASRepetition: " <<
+  // (tileLayout.toString())
+  //           << std::endl;
 
   return combineCtaCgaWithShape(std::move(tileLayout),
                                 CTALayoutAttr::getDefault(ctx, rank), shape);

third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp
Lines changed: 7 additions & 1 deletion

@@ -1,8 +1,8 @@
-
 #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
 #include "intel/include/TritonIntelGPUToLLVM/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "triton/Analysis/Allocation.h"
+#include <iostream>
 
 using namespace mlir;
 
@@ -20,17 +20,21 @@ struct AllocateSharedMemory
                   AllocateSharedMemory>::IntelAllocateSharedMemoryBase;
 
   void runOnOperation() override {
+    std::cout << "AllocateSharedMemory Start\n";
     ModuleOp mod = getOperation();
     MLIRContext *ctx = &getContext();
+    std::cout << "Before create Module Allocation\n";
     ModuleAllocation allocation(mod);
 
+    std::cout << "Before mod walk\n";
     mod.walk([&](FunctionOpInterface funcOp) {
       if (allocation.isRoot(funcOp) && allocation.getSharedMemorySize()) {
        LLVM::LLVMPointerType ptrTy = LLVM::LLVMPointerType::get(
            ctx, triton::TritonGEN::TritonGENMemorySpace::kWorkgroup);
        funcOp.insertArgument(funcOp.getNumArguments(), ptrTy, {},
                              funcOp.getLoc());
       }
+      std::cout << "Before funcOp walk\n";
       funcOp.walk([&](Operation *op) {
         auto *funcAllocation = allocation.getFuncData(funcOp);
         auto oBufferId = funcAllocation->getBufferId(op);

@@ -49,6 +53,7 @@ struct AllocateSharedMemory
                     IntegerAttr::get(IntegerType::get(ctx, 32), offset));
       });
     });
+    std::cout << "Before getSharedMemorySize\n";
     int32_t initialSharedMemorySize = 0;
     if (IntegerAttr sharedAttr =
             mod->getAttrOfType<IntegerAttr>("triton_gpu.shared"))

@@ -57,6 +62,7 @@ struct AllocateSharedMemory
                  IntegerAttr::get(IntegerType::get(ctx, 32),
                                   initialSharedMemorySize +
                                       allocation.getSharedMemorySize()));
+    std::cout << "AllocateSharedMemory End\n";
   }
 };
third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Lines changed: 0 additions & 2 deletions

@@ -1,7 +1,6 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "TargetInfo.h"
 #include "Utility.h"
-#include <iostream>
 
 #include "intel/include/Analysis/Utility.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"

@@ -112,7 +111,6 @@ struct ConvertLayoutOpConversion
     }
     if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(layout)) {
       assert(rank == 2 || rank == 3);
-      std::cout << "!!!getMultiDimOffset: dpasLayout" << std::endl;
       auto multiDimBase = ::intel::emitBaseIndexForLayout(
           loc, rewriter, targetInfo, layout, type, false);
       SmallVector<SmallVector<unsigned>> offsets;

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp
Lines changed: 2 additions & 1 deletion

@@ -2,6 +2,7 @@
 #include "../Utility.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <iostream>
 
 using ValueTable = std::map<std::array<int, 3>, Value>;
 using mlir::triton::gpu::getShapePerCTA;

@@ -334,7 +335,6 @@ Value loadOperand(ConversionPatternRewriter &rewriter, Location loc,
   SmallVector<Value> multiDimWarpId =
       LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, order);
 
-  // FIXME: Using opIdx as the dimIdx will be incorrect in 3D case.
   unsigned rank = shape.size();
   unsigned dimOuter = opIdx ? (rank - 1) : (rank - 2);
   unsigned ceilRes =

@@ -373,6 +373,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
                     const SharedMemoryObject &smemObj,
                     const LLVMTypeConverter *typeConverter, Value threadId) {
   auto descTy = cast<MemDescType>(tensor.getType());
+  std::cout << "!!! SharedToDotOperandDPAS::intel::convertLayout\n";
   switch (opIdx) {
   case 0:
     return loadOperand<0>(rewriter, loc, descTy, encoding, smemObj,
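Context for the deleted FIXME in loadOperand: the surviving line `unsigned dimOuter = opIdx ? (rank - 1) : (rank - 2);` selects each operand's non-K dimension by rank rather than by opIdx directly (for rank == 2, dim 0 for operand A, dim 1 for operand B; for rank == 3 the batch dimension shifts both), which appears to address the 3D concern the old comment raised.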

third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp
Lines changed: 18 additions & 0 deletions

@@ -1,3 +1,5 @@
+#include <iostream>
+
 #include "PatternTritonGPUOpToLLVM.h"
 #include "Utility.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"

@@ -57,6 +59,7 @@ struct LocalAllocOpConversion
   LogicalResult
   matchAndRewrite(triton::gpu::LocalAllocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    std::cout << "LocalAllocOpConversion start\n";
     if (!op.isSharedMemoryAlloc())
       return failure();
     Location loc = op->getLoc();

@@ -91,6 +94,7 @@ struct LocalAllocOpConversion
     }
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
     rewriter.replaceOp(op, retVal);
+    std::cout << "LocalAllocOpConversion end\n";
     return success();
   }

@@ -122,17 +126,20 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
   LogicalResult
   matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    std::cout << "LocalLoadOpConversion start\n";
     MemDescType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
     if (isa<SharedEncodingAttr>(srcLayout) &&
         isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
            dstLayout)) {
+      std::cout << "shared -> distributed\n";
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }
     if (isa<DotOperandEncodingAttr>(dstLayout)) {
+      std::cout << "shared -> dot_operand\n";
       return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), rewriter);
     }
     return failure();

@@ -154,6 +161,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
 
     auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                    llvmElemTy, rewriter);
+    std::cout << "!!! smemObj strides rank: " << smemObj.getStrides().size()
+              << "\n";
+
     Value res;
     if (!isOuter) {
       res = SharedToDotOperandDPAS::intel::convertLayout(

@@ -176,6 +186,14 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto sharedLayout =
         cast<SharedEncodingAttr>(op.getSrc().getType().getEncoding());
 
+    sharedLayout.dump();
+    std::cout << "!!! sharedLayout order: "
+              << "\n";
+    for (auto o : sharedLayout.getOrder()) {
+      std::cout << o << " ";
+    }
+    std::cout << std::endl;
+
     int K;
     if (dotLayout.getOpIdx() == 0) // $a
       K = op.getType().getShape()[sharedLayout.getOrder()[0]];
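One caveat with the block above: Attribute::dump() writes to llvm::errs(), while the surrounding prints go to std::cout, so the dumped layout can interleave unpredictably with the order line. A sketch that keeps everything on one stream (print(raw_ostream &) is the standard MLIR API; this is a suggestion, not what the commit does):

    // Keep the layout dump and the order list on the same stream.
    sharedLayout.print(llvm::errs());
    llvm::errs() << "\n!!! sharedLayout order: ";
    for (auto o : sharedLayout.getOrder())
      llvm::errs() << o << " ";
    llvm::errs() << "\n";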

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
Lines changed: 0 additions & 2 deletions

@@ -501,7 +501,6 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
   RewriterBase::InsertionGuard guard(rewriter);
   SmallVector<Value> result;
   if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(layout)) {
-    printf("emitBaseIndexForLayoutImpl: dpasLayout\n");
     result = emitBaseIndexForDpasLayout(loc, rewriter, dpasLayout, type);
   } else if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
     auto parentLayout = sliceLayout.getParent();

@@ -514,7 +513,6 @@ emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
     // CTAOffset has been added in emitBaseIndexForLayout of parentLayout
     return result;
   } else if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    printf("emitBaseIndexForLayoutImpl: DotOperandLayout\n");
     result = emitBaseIndexForDotOpLayout(loc, rewriter, dotLayout, type);
   } else {
     return mlir::emitBaseIndexForLayoutImpl(loc, rewriter, target, layout, type,
