Skip to content

Commit 234cd11

Browse files
saagarjha authored and
meta-codesync[bot] committed
[Cherry-pick][RESOLVED] Fix histograms for complex replicated layouts (#7938) (#546)
Summary: ⚠️ **MERGE CONFLICTS DETECTED** ⚠️ This cherry-pick contains merge conflicts that require manual resolution. Original Commit: 078954b Original Author: Saagar Jha Original Date: 2025-08-29 05:05:37 -0700 **Action Required:** 1. Check out this branch locally 2. Resolve the merge conflicts in the affected files 3. Commit the resolved changes 4. Update this PR Original commit message: ``` Fix histograms for complex replicated layouts (#7938) The current histogram code assumes that replication across a warp is done in a way that involves the first n threads having unique data. This is not a valid assumption; in fact the function it calls to get this layout, getThreadsPerWarp, describes one such layout and how it's returned, so the histogram code actually discards that information. To fix this, we actually remove the uniquing code that masks out threads possessing duplicate data. Instead we have everyone participate and adjust for the overcounting that results by computing the "replication factor". This is much easier than computing the correct mask, which is nontrivial in the general case. ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. The conflicts have been committed with conflict markers for easier resolution. Pull Request resolved: #546 Reviewed By: agron911 Differential Revision: D85907975 Pulled By: dshi7 fbshipit-source-id: 218021919c1205249fe7a6783a0a186e91a56411
1 parent 36cde5d commit 234cd11

File tree

3 files changed

+56
-24
lines changed

3 files changed

+56
-24
lines changed

lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 18 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@ static SmallVector<Value> computeWarpLevelHistogram(
2525
int numBits = llvm::Log2_64(numBins);
2626
int numBitsLaneId = llvm::Log2_64(numThreadPerWarp);
2727
unsigned numElementsPerThreads = getTotalElemsPerThread(srcType);
28-
unsigned numThreadWithUniqueData = getThreadsPerWarp(srcType)[0];
2928
// The histogram is distributed across threads, each thread owns `numBins /
3029
// numThreadPerWarp` bins.
3130
SmallVector<Value> warpLevelHistogram(numBins / numThreadPerWarp, zero);
@@ -43,10 +42,6 @@ static SmallVector<Value> computeWarpLevelHistogram(
4342
numThreadPerWarp == 32 ? 0xFFFFFFFF : 0xFFFFFFFFFFFFFFFF;
4443
Value fullMask = b.int_val(numThreadPerWarp, fullMaskValue);
4544
Value mask = fullMask;
46-
// If not all threads have unique data, mask out the redundant ones.
47-
if (numThreadWithUniqueData < numThreadPerWarp) {
48-
mask = b.int_val(numThreadPerWarp, (1ULL << numThreadWithUniqueData) - 1);
49-
}
5045
for (int i = 0; i < numBitsLaneId; i++) {
5146
Value updateMask =
5247
b.select(b.icmp_ne(b.and_(threadId, b.i32_val(1 << i)), zero),
@@ -96,8 +91,6 @@ static SmallVector<Value> computeCrossWarpHistogram(
9691
Value threadId, int numWarps) {
9792
auto b = TritonLLVMOpBuilder(loc, rewriter);
9893
SmallVector<Value> histogramValues;
99-
unsigned numWarpsWithUniqueData = mlir::triton::gpu::getWarpsPerCTA(
100-
srcType.getEncoding(), srcType.getShape())[0];
10194
Value laneId = b.and_(threadId, b.i32_val(numThreadPerWarp - 1));
10295
// Initialize the shared memory with zeros.
10396
int64_t numElementPerThread =
@@ -112,19 +105,6 @@ static SmallVector<Value> computeCrossWarpHistogram(
112105
}
113106
b.barrier();
114107
Block *afterAtomics = nullptr;
115-
// If some warps have replicated data we need to skip those warps when
116-
// accumulating.
117-
if (numWarpsWithUniqueData < numWarps) {
118-
Block *currentBlock = rewriter.getInsertionBlock();
119-
afterAtomics =
120-
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
121-
Block *atomicBlock = rewriter.createBlock(afterAtomics);
122-
rewriter.setInsertionPointToEnd(currentBlock);
123-
Value cond = b.icmp_ult(
124-
threadId, b.i32_val(numWarpsWithUniqueData * numThreadPerWarp));
125-
rewriter.create<LLVM::CondBrOp>(loc, cond, atomicBlock, afterAtomics);
126-
rewriter.setInsertionPointToStart(atomicBlock);
127-
}
128108
// Apply atomic add to update the histogram in shared memory.
129109
for (int i = 0; i < warpLevelHistogram.size(); ++i) {
130110
Value warpLevelHistogramValue = warpLevelHistogram[i];
@@ -209,6 +189,24 @@ struct HistogramOpConversion
209189
loc, rewriter, srcType, baseSharedMemPtr, warpLevelHistogram, numBins,
210190
numThreadsPerWarp, innerDimIndices, threadId, numWarps);
211191

192+
// Depending on the layout, some threads may have duplicate data. We can
193+
// account for this by calculating a "replication factor" and dividing the
194+
// results by it to avoid overcounting.
195+
auto replicationFactor = numWarps * numThreadsPerWarp;
196+
auto threadsPerWarp = getThreadsPerWarp(srcType);
197+
auto warpsPerCTA =
198+
getWarpsPerCTA(srcType.getEncoding(), srcType.getShape());
199+
replicationFactor /= std::accumulate(
200+
threadsPerWarp.begin(), threadsPerWarp.end(), 1, std::multiplies<>());
201+
replicationFactor /= std::accumulate(warpsPerCTA.begin(), warpsPerCTA.end(),
202+
1, std::multiplies<>());
203+
204+
auto b = TritonLLVMOpBuilder(loc, rewriter);
205+
for (auto i = 0; i < histogramValue.size(); ++i) {
206+
histogramValue[i] =
207+
b.sdiv(histogramValue[i], b.i32_val(replicationFactor));
208+
}
209+
212210
Value results = packLLElements(loc, typeConverter, histogramValue, rewriter,
213211
op.getType());
214212
rewriter.replaceOp(op, results);

python/test/gluon/test_lowerings.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -240,6 +240,40 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, layout: ttgl.constexpr):
240240
])
241241

242242

243+
@pytest.mark.parametrize("M, bins", [[2048, 2], [8, 512], [32, 32]])
244+
@pytest.mark.parametrize("src_layout", [ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0]), "linear_layout"])
245+
@pytest.mark.parametrize("dst_layout", [ttgl.BlockedLayout([1], [THREADS_PER_WARP], [4], [0])])
246+
def test_histogram(M, bins, src_layout, dst_layout, device):
247+
248+
@gluon.jit
249+
def kernel(x_ptr, z_ptr, M: ttgl.constexpr, B: ttgl.constexpr, src_layout: ttgl.constexpr,
250+
dst_layout: ttgl.constexpr):
251+
offs = ttgl.arange(0, M, layout=src_layout)
252+
x = ttgl.load(x_ptr + offs)
253+
h = ttgl.histogram(x, B, layout=dst_layout)
254+
z_offs = ttgl.arange(0, B, layout=dst_layout)
255+
ttgl.store(z_ptr + z_offs, h)
256+
257+
if src_layout == "linear_layout":
258+
if M == 32:
259+
src_layout = ttgl.DistributedLinearLayout(
260+
reg_bases=[],
261+
lane_bases=[[0], [16], [4], [2], [1]] + [[0]] * (THREADS_PER_WARP >> 6),
262+
warp_bases=[[0], [8]],
263+
block_bases=[],
264+
shape=(M, ),
265+
)
266+
else:
267+
pytest.skip("Linear layout is specialized for 32 elements")
268+
269+
torch.manual_seed(0)
270+
x = torch.randint(0, bins, (M, ), dtype=torch.int32, device=device)
271+
z = torch.zeros((bins, ), dtype=torch.int32, device=device)
272+
z_torch = torch.histc(x.float(), bins=bins, min=0, max=bins - 1).to(torch.int32)
273+
kernel[(1, )](x, z, M, bins, src_layout, dst_layout, num_warps=4)
274+
torch.testing.assert_close(z, z_torch, atol=0, rtol=0)
275+
276+
243277
@pytest.mark.parametrize("M", [64, 128, 256])
244278
@pytest.mark.parametrize("src_layout", _1d_layouts)
245279
@pytest.mark.parametrize("dst_layout", _1d_layouts)

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -187,10 +187,10 @@ def mangle(self):
187187

188188
def __hash__(self):
189189
return hash((
190-
tuple(self.reg_bases),
191-
tuple(self.lane_bases),
192-
tuple(self.warp_bases),
193-
tuple(self.block_bases),
190+
tuple(map(tuple, self.reg_bases)),
191+
tuple(map(tuple, self.lane_bases)),
192+
tuple(map(tuple, self.warp_bases)),
193+
tuple(map(tuple, self.block_bases)),
194194
tuple(self.shape),
195195
))
196196

0 commit comments

Comments (0)