@@ -532,10 +532,6 @@ We call each individual tile "rep".
532532 InterfaceMethod<"Get the shape of the values per thread.",
533533 "SmallVector<unsigned>",
534534 "getSizePerThread">,
535-
536- InterfaceMethod<"Gets the number of contiguous elements per thread.",
537- "SmallVector<unsigned>",
538- "getContigPerThread">,
539535 InterfaceMethod<"Convert to LinearLayout.",
540536 "LinearLayout",
541537 "toLinearLayout",
819815 }]>
820816 ];
821817
822- let extraClassDeclaration = extraDistributedDeclaration # [{
823- SmallVector<unsigned> getContigPerThread() {
824- // Block encoding is dense stride layout. The elements per thread are contiguous.
825- return getSizePerThread();
826- };
827- }];
818+ let extraClassDeclaration = extraDistributedDeclaration;
828819
829820 let hasCustomAssemblyFormat = 1;
830821}
@@ -972,17 +963,6 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
972963 SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
973964 SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
974965 SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
975-
976- SmallVector<unsigned> getContigPerThread() {
977- auto rank = getWarpsPerCTA().size();
978- SmallVector<unsigned> contigPerThread(rank, 1);
979- if (getIsTransposed())
980- contigPerThread[rank - 1] = 4;
981- else
982- contigPerThread[rank - 2] = 4;
983- return contigPerThread;
984- };
985-
986966 }];
987967
988968 let genVerifyDecl = 1;
@@ -1100,16 +1080,6 @@ Row |
11001080 SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
11011081 SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
11021082 static SmallVector<unsigned> getMNKDimPerInstr();
1103-
1104- SmallVector<unsigned> getContigPerThread() {
1105- auto rank = getWarpsPerCTA().size();
1106- assert(rank == 2 || rank == 3);
1107- SmallVector<unsigned> contigPerThread(rank, 1);
1108- if (getVersion() == 2) {
1109- contigPerThread[rank - 2] = 8;
1110- }
1111- return contigPerThread;
1112- };
11131083 }];
11141084}
11151085
@@ -1219,15 +1189,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12191189 SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
12201190 SmallVector<unsigned> getThreadsPerWarpForOperand(int opIdx) const;
12211191 SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
1222-
1223- SmallVector<unsigned> getContigPerThread() {
1224- assert(isAmpere() || isHopper());
1225- auto rank = getWarpsPerCTA().size();
1226- SmallVector<unsigned> contigPerThread(rank, 1);
1227- contigPerThread[rank - 1] = 2;
1228- return contigPerThread;
1229- };
1230-
12311192 }];
12321193
12331194 let hasCustomAssemblyFormat = 1;
@@ -1273,13 +1234,6 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
12731234 let extraClassDeclaration = extraDistributedDeclaration # [{
12741235 template<class T>
12751236 SmallVector<T> paddedShape(ArrayRef<T> shape) const;
1276-
1277- SmallVector<unsigned> getContigPerThread() {
1278- auto parentLayout = mlir::cast<DistributedEncodingTrait>(getParent());
1279- auto parentContigPerThread = parentLayout.getContigPerThread();
1280- parentContigPerThread.erase(parentContigPerThread.begin() + getDim());
1281- return parentContigPerThread;
1282- };
12831237 }];
12841238
12851239 let hasCustomAssemblyFormat = 1;
@@ -1347,20 +1301,7 @@ vecIdx (index of the element in the quad; this is always along the k-dim)
13471301
13481302 let assemblyFormat = "`<` `{` struct(params) `}` `>`";
13491303 let genVerifyDecl = 1;
1350- let extraClassDeclaration = extraDistributedDeclaration # [{
1351- SmallVector<unsigned> getContigPerThread() {
1352- auto rank = getWarpsPerCTA().size();
1353- assert(rank == 2 || rank == 3);
1354- SmallVector<unsigned> contigPerThread(rank, 1);
1355- auto kWidth = getKWidth();
1356- assert(kWidth != 0 && "Do not support kWidth=0");
1357- if (getOpIdx() == 0)
1358- contigPerThread[rank - 1] = kWidth;
1359- else
1360- contigPerThread[rank - 2] = kWidth;
1361- return contigPerThread;
1362- };
1363- }];
1304+ let extraClassDeclaration = extraDistributedDeclaration;
13641305}
13651306
13661307def TTG_SharedMemorySpace : AttrDef<TritonGPU_Dialect, "SharedMemorySpace"> {
0 commit comments