
Commit 6b996cc

Cleanup: remove leftover std::cout debug statements from the allocation analysis, the TritonGPU dialect, and the Intel backend; enable debug info in IR dumps.
1 parent 77422cb commit 6b996cc

File tree

10 files changed (+1, -131 lines)


lib/Analysis/Allocation.cpp

Lines changed: 0 additions & 24 deletions
@@ -65,7 +65,6 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
                                                RankedTensorType dstTy) {
   Attribute srcLayout = srcTy.getEncoding();
   Attribute dstLayout = dstTy.getEncoding();
-  std::cout << "- in getRepShapeForCvt\n";
 
   if (!cvtNeedsSharedMemory(srcTy, dstTy)) {
     return {};
@@ -82,10 +81,6 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
   auto dstShapePerCTA = getShapePerCTA(dstTy);
   auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
   auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
-  std::cout << "!!!shapePerCTA: " << srcShapePerCTA.size() << " "
-            << dstShapePerCTA.size() << "\n";
-  std::cout << "!!!shapePerCTATile: " << srcShapePerCTATile.size() << " "
-            << dstShapePerCTATile.size() << "\n";
 
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> repShape(rank);
@@ -112,9 +107,7 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
 ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
                                      RankedTensorType dstTy) {
   // Initialize vector sizes and stride
-  std::cout << "getRepShapeForCvt start\n";
   auto repShape = getRepShapeForCvt(srcTy, dstTy);
-  std::cout << "repShape rank: " << repShape.size() << "\n";
   if (repShape.empty())
     return ScratchConfig({}, {});
   ScratchConfig scratchConfig(repShape, repShape);
@@ -126,24 +119,13 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
 
   auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
   scratchConfig.order = outOrd;
-  std::cout << "inOrd: ";
-  for (auto i : inOrd) {
-    std::cout << i << " ";
-  }
-  std::cout << "rank: " << inOrd.size() << "\n";
-  std::cout << "outOrd: ";
-  for (auto i : outOrd) {
-    std::cout << i << " ";
-  }
-  std::cout << "rank: " << outOrd.size() << "\n";
 
   unsigned srcContigPerThread =
       getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
   unsigned dstContigPerThread =
       getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
   // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
   // that we cannot do vectorization.
-  std::cout << "no index issue in getUniqueContigPerThread\n";
   unsigned innerDim = rank - 1;
   scratchConfig.inVec = outOrd[0] != innerDim ? 1
                         : inOrd[0] != innerDim ? 1
@@ -252,33 +234,27 @@ class AllocationAnalysis {
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
     } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
-      std::cout << "getScratchValueSize from ConvertLayoutOp\n";
       auto srcTy = cvtLayout.getSrc().getType();
       auto dstTy = cvtLayout.getType();
       auto srcEncoding = srcTy.getEncoding();
       auto dstEncoding = dstTy.getEncoding();
       if (mlir::isa<SharedEncodingAttr>(srcEncoding) ||
           mlir::isa<SharedEncodingAttr>(dstEncoding)) {
         // Conversions from/to shared memory do not need scratch memory.
-        std::cout << "-- ConvertLayoutOp from/to shared memory\n";
         return;
       }
       // ConvertLayoutOp with both input/output non-shared_layout
       // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
       // also possible to realize it with other approaches in restricted
       // conditions, such as warp-shuffle
-      std::cout << "-- getScratchConfigForCvt\n";
       auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
-      std::cout << "-- getNumScratchElements\n";
       auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
       auto bytes =
           isa<triton::PointerType>(srcTy.getElementType())
              ? elems * kPtrBitWidth / 8
              : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
-      std::cout << "-- ConvertLayoutOp from/to non-shared memory: " << bytes
-                << " bytes\n";
     } else if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
       auto value = op->getOperand(0);
       // only scalar requires scratch memory
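Note on the removals above: every deleted line was ad-hoc std::cout tracing. As a point of reference, here is a minimal sketch of the idiomatic LLVM replacement, should this tracing ever be needed again; the DEBUG_TYPE tag and helper name are illustrative, not names from this repository:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "triton-allocation" // illustrative tag

// Hypothetical helper mirroring the removed repShape prints.
static void traceRepShape(llvm::ArrayRef<unsigned> repShape) {
  LLVM_DEBUG({
    llvm::dbgs() << "repShape:";
    for (unsigned d : repShape)
      llvm::dbgs() << " " << d;
    llvm::dbgs() << "\n";
  });
}

Unlike std::cout, this output compiles away in release builds and is gated at runtime with -debug-only=triton-allocation in assertion-enabled builds.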

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 0 additions & 10 deletions
@@ -384,16 +384,6 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
 SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                     ArrayRef<int64_t> shape) {
   unsigned rank = shape.size();
-  std::cout << "!!!GPU dialect - getShapePerCTA\n";
-  std::cout << "CTASplitNum: ";
-  for (auto i : CTASplitNum) {
-    std::cout << i << " ";
-  }
-  std::cout << "\nshape: ";
-  for (auto i : shape) {
-    std::cout << i << " ";
-  }
-  std::cout << "\n";
 
   SmallVector<int64_t> shapePerCTA(rank);
   for (unsigned i = 0; i < rank; ++i) {

python/src/ir.cc

Lines changed: 1 addition & 1 deletion
@@ -1622,7 +1622,7 @@ void init_triton_ir(py::module &&m) {
       if (haveDump) {
         auto printingFlags = OpPrintingFlags();
         printingFlags.elideLargeElementsAttrs(16);
-        // printingFlags.enableDebugInfo();
+        printingFlags.enableDebugInfo();
         auto printAlways = [funcToDump](Pass *, Operation *op) -> bool {
           if (funcToDump.empty())
             return true;
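This hunk is the commit's only functional change: the flag was previously commented out. A minimal sketch of its effect, assuming standard MLIR printing APIs rather than code from this repository: with enableDebugInfo(), each op in a dump carries its loc(...) source location, so pass output can be traced back to the originating source line.

#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper configured the same way as the dump code above.
void dumpWithLocations(mlir::Operation *op) {
  mlir::OpPrintingFlags flags;
  flags.elideLargeElementsAttrs(16); // elide constants over 16 elements
  flags.enableDebugInfo();           // print a trailing loc(...) per op
  op->print(llvm::outs(), flags);
}

With the flag on, a dumped op reads like %0 = arith.addf %a, %b : f32 loc("kernel.py":12:8) instead of stopping at the type.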

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 0 additions & 49 deletions
@@ -56,8 +56,6 @@ LinearLayout identityND(StringAttr inDimName, ArrayRef<unsigned> shape,
   LinearLayout ret = LinearLayout::empty();
   for (int i = 0; i < shape.size(); i++) {
     // Start with the most-minor dimension, which is order[0].
-    // std::cout << "i: " << i << " shape[i]: " << shape[i]
-    //           << " order[i]: " << order[i] << std::endl;
     int dim = order[i];
     ret *= LinearLayout::identity1D(shape[dim], inDimName, outDimNames[dim]);
   }
@@ -291,16 +289,6 @@ LinearLayout ensureLayoutNotSmallerThan(
     assert(actualSize > desiredSize ||
            desiredSize % actualSize == 0 && "bad shape");
     ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName);
-    // std::cout << "actualSize: " << actualSize << " desiredSize: " <<
-    // desiredSize
-    //           << std::endl;
-    // std::cout << "outDimName: " << outDimName.str() << std::endl;
-    // std::cout << "identity1D: "
-    //           << LinearLayout::identity1D(desiredSize / actualSize, kDim,
-    //                                       outDimName)
-    //                  .toString()
-    //           << std::endl;
-    // std::cout << "ret: " << ret.toString() << std::endl;
     assert(ret.getOutDimSize(outDimName) >= desiredSize && "bad grow");
   }
   return ret;
@@ -324,12 +312,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
 
   SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
 
-  std::cout << "shape: ";
-  for (auto s : shape) {
-    std::cout << s << ", ";
-  }
-  std::cout << std::endl;
-
   llvm::SmallDenseMap<StringAttr, int64_t> labeledShape;
   for (auto [dim, size] : llvm::zip(outDimNames, shape)) {
     labeledShape[dim] = size;
@@ -338,7 +320,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
   LinearLayout cgaLayout =
       ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape)
           .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
-  // std::cout << "\ncgaLayout: " << cgaLayout.toString() << std::endl;
 
   // Calculate the shape of the ctaLayout, which is `shape` divided by the
   // cgaLayout's size.
@@ -347,32 +328,19 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
              llvm::to_vector(cgaLayout.getOutDimNames()) &&
          "bad layout");
 
-  // std::cout << "ctaShape: ";
   for (auto dim : ctaLayout.getOutDimNames()) {
     ctaShape[dim] =
         std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim));
-    // std::cout << ctaShape[dim] << ", ";
   }
-  // std::cout << std::endl;
 
-  std::cout << "ensureLayoutNotSmallerThan start" << std::endl;
   ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape);
-  // std::cout << "\nctaLayout not smaller than: " << ctaLayout.toString()
-  //           << std::endl;
-  std::cout << "ensureLayoutNotLargerThan start" << std::endl;
   ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape);
-  // std::cout << "\nctaLayout not larger than: " << ctaLayout.toString()
-  //           << std::endl;
 
-  // std::cout << "\ncta * cga: " << (ctaLayout * cgaLayout).toString()
-  //           << std::endl;
   LinearLayout ret =
       (std::move(ctaLayout) * std::move(cgaLayout)).transposeOuts(outDimNames);
   for (auto dim : ret.getOutDimNames()) {
     assert(ret.getOutDimSize(dim) == labeledShape[dim] && "bad shape");
   }
-  // std::cout << "\ncombineCtaCgaWithShape: " << ret.toString() << std::endl;
-  std::cout << "combineCtaCgaWithShape end" << std::endl;
   return ret;
 }
 
@@ -569,7 +537,6 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
         LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
 
   } else if (opIdx == 1) { // Operand B
-    std::cout << "\nOperand B" << std::endl;
     auto regBasesB = DPASRegBasesB(opsPerChannel, executionSize, threadsPerWarp,
                                    systolicDepth);
     auto laneBasesB =
@@ -591,32 +558,20 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
     tileLayout *=
         LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
   } else { // opIdx=2 -> Operand C
-    std::cout << "\nOperand C" << std::endl;
     auto regBasesC = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
     auto laneBasesC =
         DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
     tileLayout = LinearLayout({{kRegister, regBasesC}, {kLane, laneBasesC}},
                               ArrayRef(outDimNames).take_back(2));
-    // std::cout << tileLayout.toString() << std::endl;
     // The per-inst layout is repeated at each repCluster.
     // Hence, multiply with the identity layouts starting from the
     // least significant dimension.
     dimNonK = rank - 2;
     dimK = rank - 1;
     tileLayout *= LinearLayout::identity1D(repCluster[dimK], kRegister,
                                            outDimNames[dimK]);
-    // std::cout << (LinearLayout::identity1D(repCluster[dimK], kRegister,
-    //                                        outDimNames[dimK])
-    //                   .toString())
-    //           << std::endl;
-    // std::cout << (tileLayout.toString()) << std::endl;
     tileLayout *= LinearLayout::identity1D(repCluster[dimNonK], kRegister,
                                            outDimNames[dimNonK]);
-    // std::cout << (LinearLayout::identity1D(repCluster[dimNonK], kRegister,
-    //                                        outDimNames[dimNonK])
-    //                   .toString())
-    //           << std::endl;
-    // std::cout << (tileLayout.toString()) << std::endl;
 
     // // The identical layout is repeated among warps
     tileLayout *=
@@ -626,7 +581,6 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
     if (rank == 3)
       tileLayout *=
           LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
-    // std::cout << (tileLayout.toString()) << std::endl;
   }
 
   // Lastly, the layout repeats to match the shape.
@@ -651,9 +605,6 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
   if (rank == 3)
     tileLayout *=
         LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);
-  // std::cout << "\ntileLayout with DPASRepetition: " <<
-  // (tileLayout.toString())
-  //           << std::endl;
 
   return combineCtaCgaWithShape(std::move(tileLayout),
                                 CTALayoutAttr::getDefault(ctx, rank), shape);
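The prints removed in this file were all dumping LinearLayout values via toString(). Here is a minimal sketch of the composition pattern these functions are built on, assuming Triton's LinearLayout utility (include/triton/Tools/LinearLayout.h) with the identity1D and operator* seen in the hunks above; dimension names are illustrative:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"
#include "triton/Tools/LinearLayout.h"

void printTinyLayout(mlir::MLIRContext *ctx) {
  auto S = [&](llvm::StringRef s) { return mlir::StringAttr::get(ctx, s); };
  using mlir::triton::LinearLayout;
  // identity1D(n, in, out) maps input i to output coordinate i for i < n;
  // operator* concatenates bases, so the product assigns 2 * 4 = 8
  // "register" values to a 2x4 (dim0 x dim1) tile.
  LinearLayout layout =
      LinearLayout::identity1D(2, S("register"), S("dim0")) *
      LinearLayout::identity1D(4, S("register"), S("dim1"));
  llvm::errs() << layout.toString() << "\n";
}

This is the same pattern identityND and DPAStoLinearLayout follow: build a per-tile layout, then multiply by identity1D layouts to replicate it across registers, lanes, and warps.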

third_party/intel/lib/TritonIntelGPUToLLVM/AllocateSharedMemory.cpp

Lines changed: 0 additions & 6 deletions
@@ -20,21 +20,17 @@ struct AllocateSharedMemory
                           AllocateSharedMemory>::IntelAllocateSharedMemoryBase;
 
   void runOnOperation() override {
-    std::cout << "AllocateSharedMemory Start\n";
     ModuleOp mod = getOperation();
     MLIRContext *ctx = &getContext();
-    std::cout << "Before create Module Allocation\n";
     ModuleAllocation allocation(mod);
 
-    std::cout << "Before mod walk\n";
     mod.walk([&](FunctionOpInterface funcOp) {
       if (allocation.isRoot(funcOp) && allocation.getSharedMemorySize()) {
         LLVM::LLVMPointerType ptrTy = LLVM::LLVMPointerType::get(
             ctx, triton::TritonGEN::TritonGENMemorySpace::kWorkgroup);
         funcOp.insertArgument(funcOp.getNumArguments(), ptrTy, {},
                               funcOp.getLoc());
       }
-      std::cout << "Before funcOp walk\n";
       funcOp.walk([&](Operation *op) {
         auto *funcAllocation = allocation.getFuncData(funcOp);
         auto oBufferId = funcAllocation->getBufferId(op);
@@ -53,7 +49,6 @@ struct AllocateSharedMemory
                          IntegerAttr::get(IntegerType::get(ctx, 32), offset));
       });
     });
-    std::cout << "Before getSharedMemorySize\n";
     int32_t initialSharedMemorySize = 0;
     if (IntegerAttr sharedAttr =
             mod->getAttrOfType<IntegerAttr>("triton_gpu.shared"))
@@ -62,7 +57,6 @@ struct AllocateSharedMemory
                  IntegerAttr::get(IntegerType::get(ctx, 32),
                                   initialSharedMemorySize +
                                       allocation.getSharedMemorySize()));
-    std::cout << "AllocateSharedMemory End\n";
   }
 };
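For context on the pass whose tracing was stripped: it walks root functions and, when shared memory is allocated, appends a workgroup pointer argument. A minimal sketch of that pattern using the same FunctionOpInterface::insertArgument call visible in the hunk above; the helper name and the bare address-space constant are illustrative stand-ins for the TritonGEN enum:

#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

// Hypothetical helper mirroring the insertArgument call in the diff.
void appendSharedMemArg(mlir::FunctionOpInterface funcOp,
                        mlir::MLIRContext *ctx) {
  // The pass uses triton::TritonGEN::TritonGENMemorySpace::kWorkgroup;
  // 3 is the conventional shared/workgroup address space on GPU targets.
  auto ptrTy = mlir::LLVM::LLVMPointerType::get(ctx, /*addressSpace=*/3);
  funcOp.insertArgument(funcOp.getNumArguments(), ptrTy,
                        /*argAttrs=*/{}, funcOp.getLoc());
}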

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 0 additions & 9 deletions
@@ -40,7 +40,6 @@ struct ConvertLayoutOpConversion
   LogicalResult
   matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    std::cout << "ConvertLayoutOpConversion" << std::endl;
     RankedTensorType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
@@ -66,7 +65,6 @@ struct ConvertLayoutOpConversion
                                  RankedTensorType type,
                                  ArrayRef<unsigned> multiDimCTAInRepId,
                                  ArrayRef<unsigned> shapePerCTATile) const {
-    std::cout << "getMultiDimOffset" << std::endl;
     auto shape = type.getShape();
     unsigned rank = shape.size();
     if (auto blockedLayout = dyn_cast<BlockedEncodingAttr>(layout)) {
@@ -143,7 +141,6 @@ struct ConvertLayoutOpConversion
                       ArrayRef<unsigned> origRepShape,
                       ArrayRef<unsigned> outOrd, SmallVector<Value> &vals,
                       Value smemBase) const {
-    std::cout << "processReplica" << std::endl;
     auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
     auto layout = type.getEncoding();
     auto rank = type.getRank();
@@ -229,7 +226,6 @@ struct ConvertLayoutOpConversion
   lowerDistributedToDistributed(triton::gpu::ConvertLayoutOp op,
                                 OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
-    std::cout << "lowerDistributedToDistributed" << std::endl;
     auto loc = op.getLoc();
     auto typeConverter = getTypeConverter();
     RankedTensorType srcTy = op.getSrc().getType();
@@ -329,7 +325,6 @@ struct ConvertLayoutOpConversion
                                         ConversionPatternRewriter &rewriter,
                                         Value vals,
                                         RankedTensorType srcType) const {
-    std::cout << "getValuesFromDpasLayoutStruct" << std::endl;
     SmallVector<Value> elems = unpackLLElements(loc, vals, rewriter);
     auto dpasLayout = dyn_cast<DpasEncodingAttr>(srcType.getEncoding());
 
@@ -374,7 +369,6 @@ struct ConvertLayoutOpConversion
   Value composeValuesToDotOperandLayoutStruct(
       Location loc, ConversionPatternRewriter &rewriter, const ValueTable &vals,
       RankedTensorType dstType) const {
-    std::cout << "composeValuesToDotOperandLayoutStruct" << std::endl;
     auto dotLayout = dyn_cast<DotOperandEncodingAttr>(dstType.getEncoding());
     auto dpasLayout = dyn_cast<DpasEncodingAttr>(dotLayout.getParent());
     unsigned opIdx = dotLayout.getOpIdx();
@@ -431,7 +425,6 @@ struct ConvertLayoutOpConversion
   LogicalResult
   lowerDpasToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter) const {
-    std::cout << "lowerDpasToDotOperand" << std::endl;
     Location loc = op.getLoc();
     RankedTensorType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
@@ -464,7 +457,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   LogicalResult
   matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    std::cout << "ConvertLayoutOpUsingLinearLayoutsConversion" << std::endl;
     MLIRContext *ctx = op.getContext();
 
     const auto &shape = op.getType().getShape();
@@ -513,7 +505,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout,
                        const LinearLayout &dstLayout, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
-    std::cout << "transferWithinThread" << std::endl;
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
     StringAttr kRegister = str_attr("register");
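All nine hunks in this file remove entry tracing from conversion-pattern methods. If the goal was to see which lowering path fired, MLIR already has a gated mechanism: notifyMatchFailure records a reason that the dialect-conversion driver prints under -debug. A fragment-level sketch, with a hypothetical predicate standing in for real match logic:

// Fragment of a hypothetical pattern, for illustration only.
mlir::LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                mlir::ConversionPatternRewriter &rewriter) const override {
  if (!isSupportedConversion(op)) // hypothetical predicate
    return rewriter.notifyMatchFailure(op, "unsupported layout pair");
  // ... rewrite logic as in the methods above ...
  return mlir::success();
}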
