Skip to content

Commit 6e2da6a

Browse files
chengjunlu authored and whitneywhtsang committed
Clean the getShapePerCTATile and getShapePerCTATileForOperand interface in DPAS layout.
1 parent 1fc3915 commit 6e2da6a

File tree

5 files changed

+12
-45
lines changed

5 files changed

+12
-45
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,6 @@ along the row (resp. col) dimension.
8484
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const;
8585
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth,unsigned opIdx) const;
8686
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
87-
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
8887
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
8988
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
9089

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 8 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -147,18 +147,6 @@ SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
147147
return res;
148148
}
149149

150-
SmallVector<unsigned>
151-
DpasEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
152-
auto shapeC = getShapeC();
153-
SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
154-
size_t rank = shapeC.size();
155-
SmallVector<unsigned> shapePerCTATile(rank);
156-
llvm::transform(
157-
llvm::zip_equal(shapeC, warpsPerCTA), shapePerCTATile.begin(),
158-
[](auto entry) { return std::get<0>(entry) * std::get<1>(entry); });
159-
return shapePerCTATile;
160-
}
161-
162150
SmallVector<unsigned> DpasEncodingAttr::getRepOrder() const {
163151
llvm::report_fatal_error("NYI. DpasEncodingAttr::getRepOrder");
164152
}
@@ -174,7 +162,14 @@ DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
174162
assert((rank == 2 || rank == 3) && "Unexpected rank of mma layout");
175163

176164
SmallVector<unsigned> elemsPerThread(rank, 1);
177-
auto shapePerCTATile = getShapePerCTATile(shape);
165+
166+
auto shapeC = getShapeC();
167+
SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
168+
SmallVector<unsigned> shapePerCTATile(rank);
169+
llvm::transform(
170+
llvm::zip_equal(shapeC, warpsPerCTA), shapePerCTATile.begin(),
171+
[](auto entry) { return std::get<0>(entry) * std::get<1>(entry); });
172+
178173
unsigned tilesRow =
179174
ceil<unsigned>(shape[rank - 2], shapePerCTATile[rank - 2]);
180175
unsigned tilesCol =
@@ -306,33 +301,6 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
306301
return res;
307302
}
308303

309-
SmallVector<unsigned>
310-
DpasEncodingAttr::getShapePerCTATileForOperand(ArrayRef<int64_t> shape,
311-
int kWidth, int opIdx) const {
312-
auto parentShapePerCTATile = getShapePerCTATile(shape);
313-
size_t rank = parentShapePerCTATile.size();
314-
assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
315-
if (opIdx == 0) {
316-
auto shapeA = getShapeA();
317-
return (rank == 2)
318-
? SmallVector<unsigned>{parentShapePerCTATile[0], shapeA[1]}
319-
: SmallVector<unsigned>{parentShapePerCTATile[0],
320-
parentShapePerCTATile[rank - 2],
321-
shapeA[rank - 1]};
322-
}
323-
324-
if (opIdx == 1) {
325-
auto shapeB = getShapeB();
326-
return (rank == 2)
327-
? SmallVector<unsigned>{shapeB[0], parentShapePerCTATile[1]}
328-
: SmallVector<unsigned>{parentShapePerCTATile[0],
329-
shapeB[rank - 2],
330-
parentShapePerCTATile[rank - 1]};
331-
}
332-
333-
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
334-
}
335-
336304
SmallVector<unsigned>
337305
DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
338306
ArrayRef<unsigned> repCluster = getRepCluster();

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -248,8 +248,8 @@ struct ConvertLayoutOpConversion
248248
SmallVector<unsigned> outNumCTAsEachRep(rank);
249249
SmallVector<unsigned> inNumCTAs(rank);
250250
SmallVector<unsigned> outNumCTAs(rank);
251-
auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
252-
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, shape);
251+
auto srcShapePerCTATile = getShapePerCTATile(srcLayout);
252+
auto dstShapePerCTATile = getShapePerCTATile(dstLayout);
253253
auto shapePerCTA = getShapePerCTA(srcLayout, shape);
254254

255255
for (unsigned d = 0; d < rank; ++d) {

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
3636
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
3737
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
3838
auto order = triton::gpu::getOrder(layout);
39-
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape);
39+
auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout);
4040
Value warpSize = LLVM::intel::getModuleWarpSize(rewriter, loc);
4141
Value laneId = urem(tid, warpSize);
4242
Value warpId = udiv(tid, warpSize);

third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -408,7 +408,7 @@ struct ReduceOpConversion
408408
auto resultIndices = ::intel::emitIndices(loc, rewriter, targetInfo,
409409
resultLayout, resultTy, true);
410410
auto resultShape = resultTy.getShape();
411-
auto resultCTATile = getShapePerCTATile(resultLayout, resultShape);
411+
auto resultCTATile = getShapePerCTATile(resultLayout);
412412
assert(resultIndices.size() == resultElems);
413413

414414
SmallVector<Value> resultVals(resultElems);

0 commit comments

Comments
 (0)