@@ -11,8 +11,9 @@ using namespace mlir::triton;
 // only popcount those.
 static SmallVector<Value> computeWarpLevelHistogram(
     Location loc, RankedTensorType srcType, SmallVector<Value> &srcValues,
-    int numBins, int numThreadPerWarp, Value threadId,
-    ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) {
+    SmallVector<Value> &maskValues, int numBins, int numThreadPerWarp,
+    Value threadId, ConversionPatternRewriter &rewriter,
+    const TargetInfoBase &targetInfo) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   assert(numBins % numThreadPerWarp == 0 &&
          "numBins must be divisible by numThreadPerWarp");
@@ -49,6 +50,14 @@ static SmallVector<Value> computeWarpLevelHistogram(
       mask = b.and_(
           mask, b.xor_(ballotBits[i + numBits - numBitsLaneId], updateMask));
     }
+    // save a ballot bit to capture the input mask
+    Value inputMaskBit = fullMask;
+    if (maskValues.size() > 0) {
+      inputMaskBit = targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp),
+                                       maskValues[i]);
+    }
+    // mask out the values for which input mask is invalid
+    mask = b.and_(mask, inputMaskBit);
     // at this point, 'mask' tells you which elements are in a bin owned by this
     // thread.
     for (int k = 0; k < warpLevelHistogram.size(); k++) {
@@ -158,6 +167,12 @@ struct HistogramOpConversion
     Value input = adaptor.getSrc();
     auto typeConverter = getTypeConverter();
     SmallVector<Value> srcValues = unpackLLElements(loc, input, rewriter);
+
+    Value llMask = adaptor.getMask();
+    SmallVector<Value> maskValues;
+    if (llMask)
+      maskValues = unpackLLElements(loc, llMask, rewriter);
+
     int numBins = op.getType().getDimSize(0);
     auto mod = op->getParentOfType<ModuleOp>();
     int numThreadsPerWarp =
@@ -173,8 +188,8 @@ struct HistogramOpConversion
     auto srcType = op.getSrc().getType();
     // First compute a warp local histogram based on values owned by each warp.
     SmallVector<Value> warpLevelHistogram = computeWarpLevelHistogram(
-        loc, srcType, srcValues, numBins, numThreadsPerWarp, threadId, rewriter,
-        targetInfo);
+        loc, srcType, srcValues, maskValues, numBins, numThreadsPerWarp,
+        threadId, rewriter, targetInfo);
 
     // Then use atomic to update the histogram in shared memory.
     // TODO: we could skip this for cases with num_warps=1 as long as we can
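For readers skimming the diff, here is a minimal scalar sketch of the semantics the new `maskValues` path gives the lowering: the ballot of the mask operand is ANDed into the per-bin lane mask, so an element whose mask bit is false is never counted in any bin, and an absent mask behaves like an all-true mask. The helper name and signature below are illustrative only and are not part of the Triton sources.

```cpp
// Scalar model of a masked histogram, assuming the same semantics the
// warp-level lowering implements with ballot/popcount. An empty `mask`
// models the case where the op carries no mask operand (llMask == nullptr).
#include <cstdint>
#include <vector>

std::vector<uint32_t> maskedHistogramModel(const std::vector<int> &values,
                                           const std::vector<bool> &mask,
                                           int numBins) {
  std::vector<uint32_t> hist(numBins, 0);
  for (size_t i = 0; i < values.size(); ++i) {
    // Mirrors `mask = b.and_(mask, inputMaskBit)`: a masked-off element
    // contributes to no bin at all.
    if (!mask.empty() && !mask[i])
      continue;
    int bin = values[i];
    if (bin >= 0 && bin < numBins)
      ++hist[bin];
  }
  return hist;
}
```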