Skip to content

Commit bd88137

Browse files
authored
Calculate block load tile layout dim size from bases (#3971)
Previously we attempted to compute the total size for the 2D block loads according to the tensor layout using input parameters like tensor shape and warp shape. This can be error prone, since the tensor shape is manipulated according to the warp distribution. A cleaner solution is to modify the dimension sizes according to the bases. By loading the dimensions from the tile layout in addition to the bases, we can modify the dimension sizes using the same metrics used to construct the bases. This appears to be giving correct results using the `test_block_load` tests. I did not use this approach initially because I was concerned about the load tile being too "big", but because we incorporate strides in the loads now this approach should faithfully represent the total dimensionality of the loaded data.
1 parent 9d9ad31 commit bd88137

File tree

2 files changed

+14
-18
lines changed

2 files changed

+14
-18
lines changed

python/test/unit/intel/test_block_load.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from triton._internal_testing import is_xpu
77

88

9-
@pytest.mark.parametrize("M, N", [[256, 64], [256, 32], [128, 32], [128, 16], [128, 8], [64, 64], [64, 32], [32, 32]])
9+
@pytest.mark.parametrize("M, N",
10+
[[256, 64], [256, 32], [128, 32], [128, 16], [128, 8], [64, 64], [64, 32], [32, 32], [16, 64]])
1011
@pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"])
1112
@pytest.mark.parametrize("transpose", [True, False])
1213
@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,16 +1493,28 @@ struct LoadOpConversion
14931493
// layout.
14941494
auto bases = tileLayout.getBases();
14951495
std::vector<std::vector<int32_t>> newLoadBases;
1496+
1497+
SmallVector<std::pair<StringAttr, int32_t>> outDims;
1498+
for (auto [name, size] :
1499+
llvm::zip(tileLayout.getOutDimNames(), tileLayout.getOutDimSizes())) {
1500+
outDims.push_back(std::make_pair(name, size));
1501+
}
1502+
assert(outDims[0].first == S("dim0"));
1503+
assert(outDims[1].first == S("dim1"));
1504+
14961505
for (size_t i = 0;
14971506
i < llvm::Log2_32(numRepInner / numOperandsInnerDimPerLoad); i++) {
14981507
newLoadBases.push_back({0, static_cast<int>((1 << i) * repKStride *
14991508
numOperandsInnerDimPerLoad)});
1509+
outDims[1].second *= repKStride * numOperandsInnerDimPerLoad;
15001510
}
15011511
for (size_t i = 0; i < llvm::Log2_32(numLoadPerOutRepCluster); i++) {
15021512
newLoadBases.push_back({static_cast<int>((1 << i) * repStride), 0});
1513+
outDims[0].second *= repStride;
15031514
}
15041515
for (size_t i = 0; i < llvm::Log2_32(numRepOuter); i++) {
15051516
newLoadBases.push_back({static_cast<int>((1 << i) * repOuterStride), 0});
1517+
outDims[0].second *= repOuterStride;
15061518
}
15071519

15081520
LLVM_DEBUG({
@@ -1513,23 +1525,6 @@ struct LoadOpConversion
15131525
}
15141526
});
15151527

1516-
SmallVector<std::pair<StringAttr, int32_t>> outDims;
1517-
// Copy the existing dimensions first. This allows us to re-use the existing
1518-
// dim names as well as the sizes should the bases vector be empty (one
1519-
// load).
1520-
for (auto [name, size] :
1521-
llvm::zip(tileLayout.getOutDimNames(), tileLayout.getOutDimSizes())) {
1522-
outDims.push_back(std::make_pair(name, size));
1523-
}
1524-
if (newLoadBases.size() > 0) {
1525-
outDims[0] = std::make_pair(outDims[0].first, tensorShape[dimOuter]);
1526-
outDims[1] = std::make_pair(
1527-
outDims[1].first,
1528-
std::max(warpShape[dimInner],
1529-
static_cast<unsigned int>(tensorShape[dimInner] *
1530-
repCluster[dimInner])));
1531-
}
1532-
15331528
LLVM_DEBUG({
15341529
llvm::dbgs() << "New tile layout dimensions after adding load bases:\n";
15351530
for (size_t i = 0; i < outDims.size(); i++) {

0 commit comments

Comments
 (0)