Skip to content

Commit 942c79f

Browse files
[HistogramOpToLLVM] Sync from 2a10b48
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 9afa34e commit 942c79f

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ using namespace mlir::triton;
1111
// only popcount those.
1212
static SmallVector<Value> computeWarpLevelHistogram(
1313
Location loc, RankedTensorType srcType, SmallVector<Value> &srcValues,
14-
int numBins, int numThreadPerWarp, Value threadId,
15-
ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) {
14+
SmallVector<Value> &maskValues, int numBins, int numThreadPerWarp,
15+
Value threadId, ConversionPatternRewriter &rewriter,
16+
const TargetInfoBase &targetInfo) {
1617
auto b = TritonLLVMOpBuilder(loc, rewriter);
1718
assert(numBins % numThreadPerWarp == 0 &&
1819
"numBins must be divisible by numThreadPerWarp");
@@ -49,6 +50,14 @@ static SmallVector<Value> computeWarpLevelHistogram(
4950
mask = b.and_(
5051
mask, b.xor_(ballotBits[i + numBits - numBitsLaneId], updateMask));
5152
}
53+
// save a ballot bit to capture the input mask
54+
Value inputMaskBit = fullMask;
55+
if (maskValues.size() > 0) {
56+
inputMaskBit = targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp),
57+
maskValues[i]);
58+
}
59+
// mask out the values for which input mask is invalid
60+
mask = b.and_(mask, inputMaskBit);
5261
// at this point, 'mask' tells you which elements are in a bin owned by this
5362
// thread.
5463
for (int k = 0; k < warpLevelHistogram.size(); k++) {
@@ -158,6 +167,12 @@ struct HistogramOpConversion
158167
Value input = adaptor.getSrc();
159168
auto typeConverter = getTypeConverter();
160169
SmallVector<Value> srcValues = unpackLLElements(loc, input, rewriter);
170+
171+
Value llMask = adaptor.getMask();
172+
SmallVector<Value> maskValues;
173+
if (llMask)
174+
maskValues = unpackLLElements(loc, llMask, rewriter);
175+
161176
int numBins = op.getType().getDimSize(0);
162177
auto mod = op->getParentOfType<ModuleOp>();
163178
int numThreadsPerWarp =
@@ -173,8 +188,8 @@ struct HistogramOpConversion
173188
auto srcType = op.getSrc().getType();
174189
// First compute a warp local histogram based on values owned by each warps.
175190
SmallVector<Value> warpLevelHistogram = computeWarpLevelHistogram(
176-
loc, srcType, srcValues, numBins, numThreadsPerWarp, threadId, rewriter,
177-
targetInfo);
191+
loc, srcType, srcValues, maskValues, numBins, numThreadsPerWarp,
192+
threadId, rewriter, targetInfo);
178193

179194
// Then use atomic to update the histogram in shared memory.
180195
// TODO: we could skip this for cases with num_warps=1 as long as we can

0 commit comments

Comments
 (0)