22
33using namespace mlir ;
44using namespace mlir ::triton;
5+ using namespace mlir ::triton::gpu;
56
67// Compute a histogram within a warp. This uses an algorithm by @apgoucher
78// that does the following:
@@ -20,9 +21,7 @@ static SmallVector<Value> computeWarpLevelHistogram(
2021 Value zero = b.i32_val (0 );
2122 int numBits = llvm::Log2_64 (numBins);
2223 int numBitsLaneId = llvm::Log2_64 (numThreadPerWarp);
23- unsigned numElementsPerThreads = triton::gpu::getTotalElemsPerThread (srcType);
24- unsigned numThreadWithUniqueData = triton::gpu::getThreadsPerWarp (
25- srcType.getEncoding (), srcType.getShape ())[0 ];
24+ unsigned numElementsPerThreads = getTotalElemsPerThread (srcType);
2625 // The histogram is distributed across threads, each thread owns `numBins /
2726 // numThreadPerWarp` bins.
2827 SmallVector<Value> warpLevelHistogram (numBins / numThreadPerWarp, zero);
@@ -39,10 +38,6 @@ static SmallVector<Value> computeWarpLevelHistogram(
3938 uint64_t fullMaskValue = (1ll << numThreadPerWarp) - 1u ;
4039 Value fullMask = b.int_val (numThreadPerWarp, fullMaskValue);
4140 Value mask = fullMask;
42- // If not all threads have unique data, mask out the redundant ones.
43- if (numThreadWithUniqueData < numThreadPerWarp) {
44- mask = b.int_val (numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1 );
45- }
4641 for (int i = 0 ; i < numBitsLaneId; i++) {
4742 Value updateMask =
4843 b.select (b.icmp_ne (b.and_ (threadId, b.i32_val (1 << i)), zero),
@@ -94,8 +89,6 @@ static SmallVector<Value> computeCrossWarpHistogram(
9489 Value threadId, int numWarps) {
9590 auto b = TritonLLVMOpBuilder (loc, rewriter);
9691 SmallVector<Value> histogramValues;
97- unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTA (
98- srcType.getEncoding (), srcType.getShape ())[0 ];
9992 Value laneId = b.and_ (threadId, b.i32_val (numThreadPerWarp - 1 ));
10093 // Initialize the shared memory with zeros.
10194 int64_t numElementPerThread =
@@ -110,19 +103,6 @@ static SmallVector<Value> computeCrossWarpHistogram(
110103 }
111104 b.barrier ();
112105 Block *afterAtomics = nullptr ;
113- // If some warps have replicated data we need to skip those warps when
114- // accumulating.
115- if (numWarpsWithUniqueData < numWarps) {
116- Block *currentBlock = rewriter.getInsertionBlock ();
117- afterAtomics =
118- rewriter.splitBlock (currentBlock, rewriter.getInsertionPoint ());
119- Block *atomicBlock = rewriter.createBlock (afterAtomics);
120- rewriter.setInsertionPointToEnd (currentBlock);
121- Value cond = b.icmp_ult (
122- threadId, b.i32_val (numWarpsWithUniqueData * numThreadPerWarp));
123- rewriter.create <LLVM::CondBrOp>(loc, cond, atomicBlock, afterAtomics);
124- rewriter.setInsertionPointToStart (atomicBlock);
125- }
126106 // Apply atomic add to update the histogram in shared memory.
127107 for (int i = 0 ; i < warpLevelHistogram.size (); ++i) {
128108 Value warpLevelHistogramValue = warpLevelHistogram[i];
@@ -208,6 +188,24 @@ struct HistogramOpConversion
208188 loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins,
209189 numThreadsPerWarp, innerDimIndices, threadId, numWarps);
210190
191+ // Depending on the layout, some threads may have duplicate data. We can
192+ // account for this by calculating a "replication factor" and dividing the
193+ // results by it to avoid overcounting.
194+ auto replicationFactor = numWarps * numThreadsPerWarp;
195+ auto threadsPerWarp = getThreadsPerWarp (srcType);
196+ auto warpsPerCTA =
197+ getWarpsPerCTA (srcType.getEncoding (), srcType.getShape ());
198+ replicationFactor /= std::accumulate (
199+ threadsPerWarp.begin (), threadsPerWarp.end (), 1 , std::multiplies<>());
200+ replicationFactor /= std::accumulate (warpsPerCTA.begin (), warpsPerCTA.end (),
201+ 1 , std::multiplies<>());
202+
203+ auto b = TritonLLVMOpBuilder (loc, rewriter);
204+ for (auto i = 0 ; i < histogramValue.size (); ++i) {
205+ histogramValue[i] =
206+ b.sdiv (histogramValue[i], b.i32_val (replicationFactor));
207+ }
208+
211209 Value results = packLLElements (loc, typeConverter, histogramValue, rewriter,
212210 op.getType ());
213211 rewriter.replaceOp (op, results);
0 commit comments