@@ -1147,6 +1147,12 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

 LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
                                         TensorMemoryEncodingAttr encoding) {
+  // [Zeros in TMEM LinearLayouts]
+  // If there is a zero basis in rows 32 or 64, it means there is
+  // broadcasting, i.e. the same tensor element is duplicated in different
+  // addressable blocks. If the zero is in any other row/col (i.e. within a
+  // given warp-addressable tmem space), it means the element is not defined.
+
   // We model packed layouts as having the rows/cols dimensions of bitwidth=16
   // This means that a layout with unpacked=True is the same as one with
   // unpacked=False
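Note (editor's sketch, not part of the change): the convention described in the comment above can be illustrated with the LinearLayout helpers already used in this diff, assuming kRow and dims are set up as in tensorMemoryScalesToLinearLayout below. Zero bases at rows 32/64 broadcast across the warp-addressable 32-row blocks; a zero basis anywhere else leaves those TMEM locations undefined:

  // Broadcast: the two zero bases land at rows 32 and 64, so every
  // 32-row warp-addressable block holds the same elements of dim 0.
  auto bcast = LinearLayout::identity1D(32, kRow, dims[0]) *
               LinearLayout::zeros1D(4, kRow, dims[0]);
  // Undefined: the zero basis falls at row 16 (below 32), so TMEM rows
  // 16..31 of each block do not map to any tensor element.
  auto undef = LinearLayout::identity1D(16, kRow, dims[0]) *
               LinearLayout::zeros1D(2, kRow, dims[0]) *
               LinearLayout::identity1D(4, kRow, dims[0]);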
@@ -1182,25 +1188,26 @@ LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
   auto blockM = encoding.getBlockM();
   auto blockN = std::min<int32_t>(encoding.getBlockN(), shape[1]);
   assert(blockM == 64 || blockM == 128);
-  LinearLayout tile;
+  LinearLayout tile =
+      LinearLayout::zeros1D(encoding.getColStride(), kCol, dims[1]);
   if (blockM == 64) {
-    tile = LinearLayout::identity1D(16, kRow, dims[0]) *
-           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    tile *= LinearLayout::identity1D(16, kRow, dims[0]) *
+            LinearLayout::identity1D(blockN, kCol, dims[1]);
     auto bases = tile.getBases();
     if (shape[0] > blockM) {
       bases[kRow].push_back({64, 0});
     } else if (shape[1] > blockN) {
       bases[kRow].push_back({0, blockN});
     } else {
-      // Empty. This is modelled as broadcasting, same as for TMA(fp4)
+      // Empty, meaning the element is not defined
       bases[kRow].push_back({0, 0});
     }
     bases[kRow].push_back({16, 0});
     bases[kRow].push_back({32, 0});
     tile = LinearLayout(bases, dims);
   } else {
-    tile = LinearLayout::identity1D(blockM, kRow, dims[0]) *
-           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    tile *= LinearLayout::identity1D(blockM, kRow, dims[0]) *
+            LinearLayout::identity1D(blockN, kCol, dims[1]);
   }
   auto repsM = shape[0] / tile.getOutDimSize(dims[0]);
   auto repsN = shape[1] / tile.getOutDimSize(dims[1]);
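Note (worked example, not part of the change): in the blockM == 64 branch above, and assuming the leading zeros1D over kCol contributes no bases (column stride 1), bases[kRow] ends up as {1,0} {2,0} {4,0} {8,0}, then one of {64,0}, {0,blockN} or {0,0}, then {16,0} {32,0}. So TMEM row 16 maps to tensor row 64 when shape[0] > blockM, to tensor column blockN when shape[1] > blockN, and to nothing (undefined, per the comment added above) when the tensor fits in a single 64 x blockN block.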
@@ -1219,14 +1226,18 @@ tensorMemoryScalesToLinearLayout(ArrayRef<int64_t> shape,
   auto kRow = S("row");
   auto kCol = S("col");
   auto dims = standardOutDimNames(ctx, 2);
-  // nb. this can be done with
-  // ensureLayoutNotSmallerThan/ensureLayoutNotLargerThan but it's a bit less
-  // clear IMO
+  // See [Zeros in TMEM LinearLayouts]
   // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
   // We choose repOrder = [0, 1]
   auto tile =
       LinearLayout::identity1D(std::min<int>(32, shape[0]), kRow, dims[0]) *
+      // If shape[0] < 32, we have some rows undefined
+      LinearLayout::zeros1D(32 / std::min<int>(32, shape[0]), kRow, dims[0]) *
+      // Broadcasting
+      LinearLayout::zeros1D(4, kRow, dims[0]) *
       LinearLayout::identity1D(std::min<int>(4, shape[1]), kCol, dims[1]) *
+      // If shape[1] < 4, we have some cols undefined
+      LinearLayout::zeros1D(4 / std::min<int>(4, shape[1]), kCol, dims[1]) *
       // reps
       LinearLayout::identity1D(std::max<int>(1, shape[0] / 32), kCol, dims[0]) *
       LinearLayout::identity1D(std::max<int>(1, shape[1] / 4), kCol, dims[1]);
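Note (worked example, not part of the change): for shape = {128, 8} the tile above becomes identity1D(32, kRow, dims[0]) * zeros1D(4, kRow, dims[0]) * identity1D(4, kCol, dims[1]) * identity1D(4, kCol, dims[0]) * identity1D(2, kCol, dims[1]), since both zeros1D(1, ...) factors contribute no bases. The zero row bases land at rows 32 and 64, so all four warp-addressable 32-row blocks broadcast the same scales, while the column bases first walk dim 1 and then the reps along dim 0 and dim 1. For shape = {16, 4}, zeros1D(32/16, kRow, dims[0]) instead leaves rows 16..31 of each block undefined.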