@@ -532,10 +532,6 @@ We call each individual tile "rep".
532532 InterfaceMethod<"Get the shape of the values per thread.",
533533 "SmallVector<unsigned>",
534534 "getSizePerThread">,
535-
536- InterfaceMethod<"Gets the number of contiguous elements per thread.",
537- "SmallVector<unsigned>",
538- "getContigPerThread">,
539535 InterfaceMethod<"Convert to LinearLayout.",
540536 "LinearLayout",
541537 "toLinearLayout",
819815 }]>
820816 ];
821817
822- let extraClassDeclaration = extraDistributedDeclaration # [{
823- SmallVector<unsigned> getContigPerThread() {
824- // Block encoding is dense stride layout. The elements per thread are contiguous.
825- return getSizePerThread();
826- };
827- }];
818+ let extraClassDeclaration = extraDistributedDeclaration;
828819
829820 let hasCustomAssemblyFormat = 1;
830821}
@@ -972,17 +963,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
972963 SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
973964 SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
974965 SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
975-
976- SmallVector<unsigned> getContigPerThread() {
977- auto rank = getWarpsPerCTA().size();
978- SmallVector<unsigned> contigPerThread(rank, 1);
979- if (getIsTransposed())
980- contigPerThread[rank - 1] = 4;
981- else
982- contigPerThread[rank - 2] = 4;
983- return contigPerThread;
984- };
985-
986966 }];
987967
988968 let genVerifyDecl = 1;
@@ -1100,16 +1080,6 @@ Row |
11001080 SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
11011081 SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
11021082 static SmallVector<unsigned> getMNKDimPerInstr();
1103-
1104- SmallVector<unsigned> getContigPerThread() {
1105- auto rank = getWarpsPerCTA().size();
1106- assert(rank == 2 || rank == 3);
1107- SmallVector<unsigned> contigPerThread(rank, 1);
1108- if (getVersion() == 2) {
1109- contigPerThread[rank - 2] = 8;
1110- }
1111- return contigPerThread;
1112- };
11131083 }];
11141084}
11151085
@@ -1219,15 +1189,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12191189 SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
12201190 SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
12211191 SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1222-
1223- SmallVector<unsigned> getContigPerThread() {
1224- assert(isAmpere() || isHopper());
1225- auto rank = getWarpsPerCTA().size();
1226- SmallVector<unsigned> contigPerThread(rank, 1);
1227- contigPerThread[rank - 1] = 2;
1228- return contigPerThread;
1229- };
1230-
12311192 }];
12321193
12331194 let hasCustomAssemblyFormat = 1;
@@ -1274,13 +1235,6 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
12741235 let extraClassDeclaration = extraDistributedDeclaration # [{
12751236 template<class T>
12761237 SmallVector<T> paddedShape(ArrayRef<T> shape) const;
1277-
1278- SmallVector<unsigned> getContigPerThread() {
1279- auto parentLayout = mlir::cast<DistributedEncodingTrait>(getParent());
1280- auto parentContigPerThread = parentLayout.getContigPerThread();
1281- parentContigPerThread.erase(parentContigPerThread.begin() + getDim());
1282- return parentContigPerThread;
1283- };
12841238 }];
12851239
12861240 let hasCustomAssemblyFormat = 1;
@@ -1348,20 +1302,7 @@ vecIdx (index of the element in the quad; this is always along the k-dim)
13481302
13491303 let assemblyFormat = "`<` `{` struct(params) `}` `>`";
13501304 let genVerifyDecl = 1;
1351- let extraClassDeclaration = extraDistributedDeclaration # [{
1352- SmallVector<unsigned> getContigPerThread() {
1353- auto rank = getWarpsPerCTA().size();
1354- assert(rank == 2 || rank == 3);
1355- SmallVector<unsigned> contigPerThread(rank, 1);
1356- auto kWidth = getKWidth();
1357- assert(kWidth != 0 && "Do not support kWidth=0");
1358- if (getOpIdx() == 0)
1359- contigPerThread[rank - 1] = kWidth;
1360- else
1361- contigPerThread[rank - 2] = kWidth;
1362- return contigPerThread;
1363- };
1364- }];
1305+ let extraClassDeclaration = extraDistributedDeclaration;
13651306}
13661307
13671308def TTG_SharedMemorySpace : AttrDef<TritonGPU_Dialect, "SharedMemorySpace"> {
0 commit comments