Skip to content

Commit ecf1997

Browse files
[HistogramOpToLLVM] Sync from 078954b (#5248)
This PR fixes #5148
1 parent 01605e6 commit ecf1997

File tree

8 files changed: +20 additions, -45 deletions

scripts/skiplist/a770/gluon.txt

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,3 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5148
2-
python/test/gluon/test_lowerings.py::test_histogram[2048-2-src_layout3-dst_layout3]
3-
python/test/gluon/test_lowerings.py::test_histogram[32-32-src_layout4-dst_layout4]
41
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5149
52
python/test/gluon/test_lowerings.py::test_scan_layouts[r"True-.*"]@regexp
63
python/test/gluon/test_lowerings.py::test_reduce_layouts[r".*-True-.*"]@regexp

scripts/skiplist/arl-h/gluon.txt

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,3 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5148
2-
python/test/gluon/test_lowerings.py::test_histogram[2048-2-src_layout3-dst_layout3]
3-
python/test/gluon/test_lowerings.py::test_histogram[32-32-src_layout4-dst_layout4]
41
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5149
52
python/test/gluon/test_lowerings.py::test_scan_layouts[r"True-.*"]@regexp
63
python/test/gluon/test_lowerings.py::test_reduce_layouts[r".*-True-.*"]@regexp

scripts/skiplist/arl-s/gluon.txt

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,3 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5148
2-
python/test/gluon/test_lowerings.py::test_histogram[2048-2-src_layout3-dst_layout3]
3-
python/test/gluon/test_lowerings.py::test_histogram[32-32-src_layout4-dst_layout4]
41
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5149
52
python/test/gluon/test_lowerings.py::test_scan_layouts[r"True-.*"]@regexp
63
python/test/gluon/test_lowerings.py::test_reduce_layouts[r".*-True-.*"]@regexp

scripts/skiplist/default/gluon.txt

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,3 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5148
2-
python/test/gluon/test_lowerings.py::test_histogram[2048-2-src_layout3-dst_layout3]
3-
python/test/gluon/test_lowerings.py::test_histogram[32-32-src_layout4-dst_layout4]
41
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5149
52
python/test/gluon/test_lowerings.py::test_scan_layouts[r"True-.*"]@regexp
63
python/test/gluon/test_lowerings.py::test_reduce_layouts[r".*-True-.*"]@regexp

scripts/skiplist/lts/gluon.txt

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,3 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5147
2-
python/test/gluon/test_core.py::test_2d_tensor_early_return
3-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5148
4-
python/test/gluon/test_lowerings.py::test_histogram[2048-2-src_layout3-dst_layout3]
5-
python/test/gluon/test_lowerings.py::test_histogram[32-32-src_layout4-dst_layout4]
61
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5149
72
python/test/gluon/test_lowerings.py::test_scan_layouts[r"True-.*"]@regexp
83
python/test/gluon/test_lowerings.py::test_reduce_layouts[r".*-True-.*"]@regexp

scripts/skiplist/mtl/gluon.txt

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,3 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5148
2-
python/test/gluon/test_lowerings.py::test_histogram[2048-2-src_layout3-dst_layout3]
3-
python/test/gluon/test_lowerings.py::test_histogram[32-32-src_layout4-dst_layout4]
41
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5149
52
python/test/gluon/test_lowerings.py::test_scan_layouts[r"True-.*"]@regexp
63
python/test/gluon/test_lowerings.py::test_reduce_layouts[r".*-True-.*"]@regexp

scripts/skiplist/xe2/gluon.txt

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,3 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5148
2-
python/test/gluon/test_lowerings.py::test_histogram[2048-2-src_layout3-dst_layout3]
3-
python/test/gluon/test_lowerings.py::test_histogram[32-32-src_layout4-dst_layout4]
41
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5149
52
python/test/gluon/test_lowerings.py::test_scan_layouts[r"True-.*"]@regexp
63
python/test/gluon/test_lowerings.py::test_reduce_layouts[r".*-True-.*"]@regexp

third_party/intel/lib/TritonIntelGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 20 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
using namespace mlir;
44
using namespace mlir::triton;
5+
using namespace mlir::triton::gpu;
56

67
// Compute a histogram within a warp. This uses an algorithm by @apgoucher
78
// that does the following:
@@ -20,9 +21,7 @@ static SmallVector<Value> computeWarpLevelHistogram(
2021
Value zero = b.i32_val(0);
2122
int numBits = llvm::Log2_64(numBins);
2223
int numBitsLaneId = llvm::Log2_64(numThreadPerWarp);
23-
unsigned numElementsPerThreads = triton::gpu::getTotalElemsPerThread(srcType);
24-
unsigned numThreadWithUniqueData = triton::gpu::getThreadsPerWarp(
25-
srcType.getEncoding(), srcType.getShape())[0];
24+
unsigned numElementsPerThreads = getTotalElemsPerThread(srcType);
2625
// The histogram is distributed across threads, each thread owns `numBins /
2726
// numThreadPerWarp` bins.
2827
SmallVector<Value> warpLevelHistogram(numBins / numThreadPerWarp, zero);
@@ -39,10 +38,6 @@ static SmallVector<Value> computeWarpLevelHistogram(
3938
uint64_t fullMaskValue = (1ll << numThreadPerWarp) - 1u;
4039
Value fullMask = b.int_val(numThreadPerWarp, fullMaskValue);
4140
Value mask = fullMask;
42-
// If not all threads have unique data, mask out the redundant ones.
43-
if (numThreadWithUniqueData < numThreadPerWarp) {
44-
mask = b.int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1);
45-
}
4641
for (int i = 0; i < numBitsLaneId; i++) {
4742
Value updateMask =
4843
b.select(b.icmp_ne(b.and_(threadId, b.i32_val(1 << i)), zero),
@@ -94,8 +89,6 @@ static SmallVector<Value> computeCrossWarpHistogram(
9489
Value threadId, int numWarps) {
9590
auto b = TritonLLVMOpBuilder(loc, rewriter);
9691
SmallVector<Value> histogramValues;
97-
unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTA(
98-
srcType.getEncoding(), srcType.getShape())[0];
9992
Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1));
10093
// Initialize the shared memory with zeros.
10194
int64_t numElementPerThread =
@@ -110,19 +103,6 @@ static SmallVector<Value> computeCrossWarpHistogram(
110103
}
111104
b.barrier();
112105
Block *afterAtomics = nullptr;
113-
// If some warps have replicated data we need to skip those warps when
114-
// accumulating.
115-
if (numWarpsWithUniqueData < numWarps) {
116-
Block *currentBlock = rewriter.getInsertionBlock();
117-
afterAtomics =
118-
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
119-
Block *atomicBlock = rewriter.createBlock(afterAtomics);
120-
rewriter.setInsertionPointToEnd(currentBlock);
121-
Value cond = b.icmp_ult(
122-
threadId, b.i32_val(numWarpsWithUniqueData * numThreadPerWarp));
123-
rewriter.create<LLVM::CondBrOp>(loc, cond, atomicBlock, afterAtomics);
124-
rewriter.setInsertionPointToStart(atomicBlock);
125-
}
126106
// Apply atomic add to update the histogram in shared memory.
127107
for (int i = 0; i < warpLevelHistogram.size(); ++i) {
128108
Value warpLevelHistogramValue = warpLevelHistogram[i];
@@ -208,6 +188,24 @@ struct HistogramOpConversion
208188
loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins,
209189
numThreadsPerWarp, innerDimIndices, threadId, numWarps);
210190

191+
// Depending on the layout, some threads may have duplicate data. We can
192+
// account for this by calculating a "replication factor" and dividing the
193+
// results by it to avoid overcounting.
194+
auto replicationFactor = numWarps * numThreadsPerWarp;
195+
auto threadsPerWarp = getThreadsPerWarp(srcType);
196+
auto warpsPerCTA =
197+
getWarpsPerCTA(srcType.getEncoding(), srcType.getShape());
198+
replicationFactor /= std::accumulate(
199+
threadsPerWarp.begin(), threadsPerWarp.end(), 1, std::multiplies<>());
200+
replicationFactor /= std::accumulate(warpsPerCTA.begin(), warpsPerCTA.end(),
201+
1, std::multiplies<>());
202+
203+
auto b = TritonLLVMOpBuilder(loc, rewriter);
204+
for (auto i = 0; i < histogramValue.size(); ++i) {
205+
histogramValue[i] =
206+
b.sdiv(histogramValue[i], b.i32_val(replicationFactor));
207+
}
208+
211209
Value results = packLLElements(loc, typeConverter, histogramValue, rewriter,
212210
op.getType());
213211
rewriter.replaceOp(op, results);

0 commit comments

Comments
 (0)