
Commit 4f7a8b8

[BACKEND] Represent broadcasting in TensorMemoryLayouts (#8148)
Follow-up to triton-lang/triton#8136. We now have a faithful representation of unpacked linear layouts. Using this, we are able to remove several hacks that we used in the `tcgen05.ld/st` lowering and generally make it more robust.
1 parent aff4b7a commit 4f7a8b8

2 files changed: +138 -128 lines

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 20 additions & 9 deletions
@@ -1147,6 +1147,12 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
 LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
                                         TensorMemoryEncodingAttr encoding) {
+  // [Zeros in TMEM LinearLayouts]
+  // If there is a zero in the bases for rows 32/64, it means that there is
+  // broadcasting, i.e. the same tensor element is duplicated in different
+  // addressable blocks. If the zero is in any other row/col (i.e. within a
+  // given warp-addressable TMEM space), it means the element is not defined.
+
   // We model packed layouts as having the rows/cols dimensions of bitwidth=16
   // This means that a layout with unpacked=True is the same as one with
   // unpacked=False
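
The new comment gives a zero basis two distinct meanings depending on where it sits. Below is a minimal sketch of the mechanics (mine, not part of the commit), using the LinearLayout::identity1D / LinearLayout::zeros1D helpers that already appear in this diff; the kRow / dim0 names and the broadcastExample function are illustrative only.

// Sketch: a zero basis makes two input addresses map to the same output
// element. identity1D(16, ...) gives 4 row bits addressing 16 distinct
// elements; zeros1D(2, ...) appends one more row bit whose basis is 0, so
// rows 16..31 alias rows 0..15.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/Tools/LinearLayout.h"

using mlir::triton::LinearLayout;

LinearLayout broadcastExample(mlir::MLIRContext *ctx) {
  auto S = [&](llvm::StringRef s) { return mlir::StringAttr::get(ctx, s); };
  auto kRow = S("row");
  auto dim0 = S("dim0");
  // Whether this aliasing reads as "broadcasting" or "not defined" depends
  // on the position of the zero: across the rows=32/64 bases it duplicates
  // the element in different addressable blocks; within a warp-addressable
  // block it marks the slot as undefined.
  return LinearLayout::identity1D(16, kRow, dim0) *
         LinearLayout::zeros1D(2, kRow, dim0);
}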
@@ -1182,25 +1188,26 @@ LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
   auto blockM = encoding.getBlockM();
   auto blockN = std::min<int32_t>(encoding.getBlockN(), shape[1]);
   assert(blockM == 64 || blockM == 128);
-  LinearLayout tile;
+  LinearLayout tile =
+      LinearLayout::zeros1D(encoding.getColStride(), kCol, dims[1]);
   if (blockM == 64) {
-    tile = LinearLayout::identity1D(16, kRow, dims[0]) *
-           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    tile *= LinearLayout::identity1D(16, kRow, dims[0]) *
+            LinearLayout::identity1D(blockN, kCol, dims[1]);
     auto bases = tile.getBases();
     if (shape[0] > blockM) {
       bases[kRow].push_back({64, 0});
     } else if (shape[1] > blockN) {
       bases[kRow].push_back({0, blockN});
     } else {
-      // Empty. This is modelled as broadcasting, same as for TMA(fp4)
+      // Empty, meaning the element is not defined
       bases[kRow].push_back({0, 0});
     }
     bases[kRow].push_back({16, 0});
     bases[kRow].push_back({32, 0});
     tile = LinearLayout(bases, dims);
   } else {
-    tile = LinearLayout::identity1D(blockM, kRow, dims[0]) *
-           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    tile *= LinearLayout::identity1D(blockM, kRow, dims[0]) *
+            LinearLayout::identity1D(blockN, kCol, dims[1]);
   }
   auto repsM = shape[0] / tile.getOutDimSize(dims[0]);
   auto repsN = shape[1] / tile.getOutDimSize(dims[1]);
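
In this hunk, tile now starts from a zeros1D prefactor of size getColStride(), so the column stride is baked into the low column bits before the identity factors are multiplied in. A hedged sketch of the effect for a hypothetical stride of 2 (the stridedTile function and the kCol / dim1 names are illustrative, under the same assumptions as the sketch above):

#include <cstdint>

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "triton/Tools/LinearLayout.h"

using mlir::triton::LinearLayout;

LinearLayout stridedTile(mlir::MLIRContext *ctx, int32_t blockN) {
  auto S = [&](llvm::StringRef s) { return mlir::StringAttr::get(ctx, s); };
  auto kCol = S("col");
  auto dim1 = S("dim1");
  // The zeros factor occupies the lowest column bit, so columns 2k and 2k+1
  // address the same element of dim1. That aliasing stays inside a
  // warp-addressable block, so by the note in the previous hunk those
  // columns read as not defined rather than as broadcast.
  return LinearLayout::zeros1D(2, kCol, dim1) *
         LinearLayout::identity1D(blockN, kCol, dim1);
}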
@@ -1219,14 +1226,18 @@ tensorMemoryScalesToLinearLayout(ArrayRef<int64_t> shape,
   auto kRow = S("row");
   auto kCol = S("col");
   auto dims = standardOutDimNames(ctx, 2);
-  // nb. this can be done with
-  // ensureLayoutNotSmallerThan/ensureLayoutNotLargerThan but it's a bit less
-  // clear IMO
+  // See [Zeros in TMEM LinearLayouts]
   // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
   // We choose repOrder = [0, 1]
   auto tile =
       LinearLayout::identity1D(std::min<int>(32, shape[0]), kRow, dims[0]) *
+      // If shape[0] < 32, we have some rows undefined
+      LinearLayout::zeros1D(32 / std::min<int>(32, shape[0]), kRow, dims[0]) *
+      // Broadcasting
+      LinearLayout::zeros1D(4, kRow, dims[0]) *
       LinearLayout::identity1D(std::min<int>(4, shape[1]), kCol, dims[1]) *
+      // If shape[1] < 4, we have some cols undefined
+      LinearLayout::zeros1D(4 / std::min<int>(4, shape[1]), kCol, dims[1]) *
       // reps
       LinearLayout::identity1D(std::max<int>(1, shape[0] / 32), kCol, dims[0]) *
       LinearLayout::identity1D(std::max<int>(1, shape[1] / 4), kCol, dims[1]);
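
To see how the scale-factor composition plays out, here is a worked example of my own (not from the commit) for shape = {32, 4}, tracing each factor's contribution:

// Worked example for shape = {32, 4}:
//   identity1D(min(32, shape[0]) = 32, kRow, dims[0]) -> 5 row bits, 32 rows
//   zeros1D(32 / 32 = 1, kRow, dims[0])               -> no extra bits
//   zeros1D(4, kRow, dims[0])                         -> 2 zero row bits
//   identity1D(min(4, shape[1]) = 4, kCol, dims[1])   -> 2 col bits, 4 cols
//   zeros1D(4 / 4 = 1, kCol, dims[1])                 -> no extra bits
//   identity1D(max(1, 32 / 32) = 1, kCol, dims[0]) and
//   identity1D(max(1, 4 / 4) = 1, kCol, dims[1])      -> no reps
// The tile therefore spans all 128 TMEM rows and 4 columns, with rows
// 32..127 aliasing rows 0..31 through the zero bases: that is the
// "Broadcasting" factor above, replicating the scales across the four
// warp-addressable quadrants.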
