Commit 2a10b48

Add support for masked histograms (#6695)
# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 5e00f35 commit 2a10b48

10 files changed (+129 −31 lines)


include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 8 additions & 3 deletions

```diff
@@ -865,19 +865,24 @@ def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [
 //
 // Histogram Op
 //
-def TT_HistogramOp : TT_Op<"histogram", [Pure]> {
+def TT_HistogramOp : TT_Op<"histogram", [Pure,
+    TypesMatchWith<"mask type matches src type",
+                   "src", "mask", "getI1SameShape($_self)",
+                   "($_op.getOperands().size() <= 1) || std::equal_to<>()">]> {
   let summary = "return a histogram of the inputs.";
   let description = [{
     Return the histogram of the input tensor. The number of bins is equal to
     the dimension of the output tensor. Each bins has a width of 1 and bins
     start at 0.
   }];
 
-  let arguments = (ins TT_IntTensor:$src);
+  let arguments = (ins TT_IntTensor:$src,
+                       Optional<TT_BoolLike>:$mask);
+
   let results = (outs TT_IntTensor:$result);
 
   let assemblyFormat = [{
-    $src attr-dict `:` type($src) `->` type($result)
+    $src (`,` $mask^)? attr-dict `:` type($src) `->` type($result)
   }];
 }
```

lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 19 additions & 4 deletions

```diff
@@ -15,8 +15,9 @@ using namespace mlir::triton::gpu;
 // only popcount those.
 static SmallVector<Value> computeWarpLevelHistogram(
     Location loc, RankedTensorType srcType, SmallVector<Value> &srcValues,
-    int numBins, int numThreadPerWarp, Value threadId,
-    ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo) {
+    SmallVector<Value> &maskValues, int numBins, int numThreadPerWarp,
+    Value threadId, ConversionPatternRewriter &rewriter,
+    const TargetInfoBase &targetInfo) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   assert(numBins % numThreadPerWarp == 0 &&
          "numBins must be divisible by numThreadPerWarp");
@@ -53,6 +54,14 @@ static SmallVector<Value> computeWarpLevelHistogram(
       mask = b.and_(
           mask, b.xor_(ballotBits[i + numBits - numBitsLaneId], updateMask));
     }
+    // save a ballot bit to capture the input mask
+    Value inputMaskBit = fullMask;
+    if (maskValues.size() > 0) {
+      inputMaskBit = targetInfo.ballot(rewriter, loc, int_ty(numThreadPerWarp),
+                                       maskValues[i]);
+    }
+    // mask out the values for which input mask is invalid
+    mask = b.and_(mask, inputMaskBit);
     // at this point, 'mask' tells you which elements are in a bin owned by this
     // thread.
     for (int k = 0; k < warpLevelHistogram.size(); k++) {
@@ -159,6 +168,12 @@ struct HistogramOpConversion
     Value input = adaptor.getSrc();
     auto typeConverter = getTypeConverter();
     SmallVector<Value> srcValues = unpackLLElements(loc, input, rewriter);
+
+    Value llMask = adaptor.getMask();
+    SmallVector<Value> maskValues;
+    if (llMask)
+      maskValues = unpackLLElements(loc, llMask, rewriter);
+
     int numBins = op.getType().getDimSize(0);
     auto mod = op->getParentOfType<ModuleOp>();
     int numThreadsPerWarp =
@@ -174,8 +189,8 @@ struct HistogramOpConversion
     auto srcType = op.getSrc().getType();
     // First compute a warp local histogram based on values owned by each warps.
     SmallVector<Value> warpLevelHistogram = computeWarpLevelHistogram(
-        loc, srcType, srcValues, numBins, numThreadsPerWarp, threadId, rewriter,
-        targetInfo);
+        loc, srcType, srcValues, maskValues, numBins, numThreadsPerWarp,
+        threadId, rewriter, targetInfo);
 
     // Then use atomic to update the histogram in shared memory.
     // TODO: we could skip this for cases with num_warps=1 as long as we can
```
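The idea behind this hunk: in addition to the per-bit ballots of the bin index, the lowering now ballots the `tl.histogram` mask across the warp (`targetInfo.ballot` on `maskValues[i]`) and ANDs that word into the per-bin membership mask, so lanes holding a masked-off element never get counted. Below is a rough CPU-side Python model of just that step; the `ballot` and `warp_histogram` helpers, their shapes, and the per-bin comparison are illustrative simplifications, not the actual lowering (which popcounts ballots of the bin-index bits).

```python
def ballot(bits):
    """Pack one boolean per lane into an integer bitmask, mimicking a warp ballot."""
    word = 0
    for lane, bit in enumerate(bits):
        if bit:
            word |= 1 << lane
    return word


def warp_histogram(values, valid, num_bins):
    """One element per lane; `valid[lane]` plays the role of the tl.histogram mask."""
    input_mask_word = ballot(valid)  # analogous to the new ballot of maskValues[i]
    hist = [0] * num_bins
    for b in range(num_bins):
        in_bin = ballot(v == b for v in values)  # lanes whose element falls in bin b
        in_bin &= input_mask_word                # drop lanes whose element is masked off
        hist[b] += bin(in_bin).count("1")        # popcount
    return hist


# Lane 2 holds a 1 but is masked off, so bin 1 counts only one element.
print(warp_histogram([0, 1, 1, 3], [True, True, False, True], num_bins=4))  # [1, 1, 0, 1]
```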

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 18 additions & 5 deletions

```diff
@@ -119,12 +119,24 @@ struct CanonicalizeConvertFromHistogram
   mlir::LogicalResult
   matchAndRewrite(triton::HistogramOp op,
                   PatternRewriter &rewriter) const override {
-    auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
-    if (!convert)
+    auto src = op.getSrc();
+    auto convert = src.getDefiningOp<ConvertLayoutOp>();
+    if (!convert) {
       return failure();
+    }
+    src = convert.getSrc();
+
+    // If mask is present, convert the layout of mask to match new src layout
+    auto mask = op.getMask();
+    if (mask) {
+      auto sharedType = getI1SameShape(src.getType());
+      rewriter.setInsertionPoint(op);
+      mask = rewriter.create<ConvertLayoutOp>(op.getLoc(), sharedType, mask);
+    }
+
     rewriter.replaceOpWithNewOp<triton::HistogramOp>(
-        op, op->getResult(0).getType(), convert.getSrc());
-    return mlir::success();
+        op, op->getResult(0).getType(), src, mask);
+    return success();
   }
 };
 
@@ -263,7 +275,8 @@ struct CanonicalizeConvertFromConvert
     // For histogram ops the input and output layouts are independent, so we
     // can always fold convert into the histogram op.
     rewriter.replaceOpWithNewOp<HistogramOp>(op, op->getResult(0).getType(),
-                                             histogram.getSrc());
+                                             histogram.getSrc(),
+                                             histogram.getMask());
     return success();
   }
 
```

python/src/ir.cc

Lines changed: 15 additions & 6 deletions

```diff
@@ -1667,12 +1667,21 @@ void init_triton_ir(py::module &&m) {
             return self.create<ub::PoisonOp>(type);
           })
       .def("create_histogram",
-           [](TritonOpBuilder &self, Value operand, int numBins) -> Value {
-             return self.create<HistogramOp>(
-                 RankedTensorType::get(
-                     {static_cast<int64_t>(numBins)},
-                     IntegerType::get(operand.getContext(), 32)),
-                 operand);
+           [](TritonOpBuilder &self, Value operand, int numBins,
+              std::optional<Value> mask) -> Value {
+             if (!mask) {
+               return self.create<HistogramOp>(
+                   RankedTensorType::get(
+                       {static_cast<int64_t>(numBins)},
+                       IntegerType::get(operand.getContext(), 32)),
+                   operand);
+             } else {
+               return self.create<HistogramOp>(
+                   RankedTensorType::get(
+                       {static_cast<int64_t>(numBins)},
+                       IntegerType::get(operand.getContext(), 32)),
+                   operand, *mask);
+             }
           })
       .def("create_gather",
           [](TritonOpBuilder &self, Value src, Value indices, int axis)
```

python/test/unit/language/test_core.py

Lines changed: 30 additions & 0 deletions

```diff
@@ -2858,6 +2858,36 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
     assert (z_torch == z).all()
 
 
+# ------------------------
+# test histogram with mask
+# ------------------------
+
+
+@pytest.mark.interpreter
+@pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]])
+def test_histogram_mask(M, N, device):
+
+    @triton.jit
+    def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
+        offset1 = tl.arange(0, 2 * M)
+        offset2 = tl.arange(0, N)
+        mask = offset1 < M
+        x = tl.load(x_ptr + offset1)
+        z = tl.histogram(x, N, mask)
+        tl.store(z_ptr + offset2, z)
+
+    torch.manual_seed(17)
+    x1 = torch.randint(0, N, (M, ), device=device, dtype=torch.int32)
+    x = torch.cat((x1, x1), 0)
+    z = torch.empty(N, dtype=torch.int32, device=device)
+    # torch.histc does not work when the input type is not float and the device is CPU
+    # https://github.com/pytorch/pytorch/issues/74236
+    # This is a workaround: convert the input to float
+    z_torch = torch.histc(x1.float(), bins=N, min=0, max=N - 1)
+    histogram_kernel[(1, )](x, z, M=M, N=N)
+    assert (z_torch == z).all()
+
+
 @pytest.mark.parametrize("M, N", [(1, 64), (2, 32), (4, 16), (8, 8), (16, 4), (32, 2), (64, 1)])
 def test_scan_1d(M, N, device):
 
```

python/triton/language/core.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -2733,17 +2733,22 @@ def make_combine_region(scan_op):
 
 @_tensor_member_fn
 @builtin
-def histogram(input, num_bins, _builder=None, _generator=None):
+def histogram(input, num_bins, mask=None, _builder=None, _generator=None):
     """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0.
 
     :param input: the input tensor
     :type input: Tensor
    :param num_bins: number of histogram bins
     :type num_bins: int
+    :param mask: if `mask[idx]` is false, exclude `input[idx]` from histogram
+    :type mask: Block of `triton.int1`, optional
 
     """
     num_bins = _unwrap_if_constexpr(num_bins)
-    return semantic.histogram(input, num_bins, _builder)
+    mask = _unwrap_if_constexpr(mask)
+    if mask is not None:
+        mask = semantic.to_tensor(mask, _builder)
+    return semantic.histogram(input, num_bins, mask, _builder)
 
 
 @_tensor_member_fn
```
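For reference, a minimal usage sketch of the new `mask` argument, mirroring the end-to-end test added above. The kernel name, shapes, predicate, and `device="cuda"` choice below are illustrative, not from this PR:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def masked_histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
    offsets = tl.arange(0, M)
    x = tl.load(x_ptr + offsets)
    # only elements at even indices contribute to a bin
    z = tl.histogram(x, N, offsets % 2 == 0)
    tl.store(z_ptr + tl.arange(0, N), z)


x = torch.randint(0, 8, (1024, ), device="cuda", dtype=torch.int32)
z = torch.empty(8, dtype=torch.int32, device="cuda")
masked_histogram_kernel[(1, )](x, z, M=1024, N=8)
print(z)  # bin counts over x[0], x[2], x[4], ...
```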

python/triton/language/semantic.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -1805,10 +1805,15 @@ def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) ->
 # ===----------------------------------------------------------------------===
 
 
-def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor:
+def histogram(input: tl.tensor, num_bins: int, mask: Optional[tl.tensor], builder: ir.builder) -> tl.tensor:
     assert len(input.shape) == 1, "histogram only supports 1D input"
     assert input.dtype.is_int(), "histogram only supports integer input"
-    return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, [num_bins]))
+    if mask is not None:
+        mask = broadcast_impl_shape(mask, input.shape, builder)
+        if not mask.type.scalar.is_bool():
+            raise ValueError("Mask must have boolean scalar type")
+        mask = mask.handle
+    return tl.tensor(builder.create_histogram(input.handle, num_bins, mask), tl.block_type(tl.int32, [num_bins]))
 
 
 def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
```

python/triton/runtime/interpreter.py

Lines changed: 9 additions & 2 deletions

```diff
@@ -598,8 +598,15 @@ def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc):
     def create_make_range(self, ret_ty, start, stop):
         return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)
 
-    def create_histogram(self, data, bins):
-        return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32)
+    def create_histogram(self, data, bins, mask):
+        if mask is None:
+            mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
+        # force all masked elements to zero
+        data = np.where(mask.data, data.data, np.zeros_like(data.data))
+        histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
+        # remove overcounted elements
+        histogram[0] -= np.logical_not(mask.data).sum()
+        return TensorHandle(histogram, tl.int32)
 
     def create_gather(self, src, indices, axis):
         return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)
```
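The interpreter's trick is worth spelling out: masked-off elements are forced to zero, which inflates bin 0, and that inflation is then subtracted again, so the result equals a histogram of only the unmasked elements. A standalone NumPy illustration with made-up data:

```python
import numpy as np

data = np.array([0, 1, 1, 3, 2, 0], dtype=np.int32)
mask = np.array([True, True, False, True, False, True])
bins = 4

zeroed = np.where(mask, data, 0)                       # masked elements land in bin 0
hist = np.histogram(zeroed, bins=bins, range=(0, bins))[0]
hist[0] -= np.logical_not(mask).sum()                  # undo the inflation of bin 0

# same result as histogramming only the unmasked elements
assert (hist == np.histogram(data[mask], bins=bins, range=(0, bins))[0]).all()
print(hist)  # [2 1 0 1]
```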

test/Triton/ops.mlir

Lines changed: 7 additions & 0 deletions

```diff
@@ -243,6 +243,13 @@ tt.func @histogram(%0: tensor<512xi32>) {
   tt.return
 }
 
+// CHECK-LABEL: masked_histogram
+tt.func @masked_histogram(%0: tensor<512xi32>, %1: tensor<512xi1>) {
+  // CHECK: tt.histogram %{{.+}}, %{{.+}} : tensor<512xi32> -> tensor<16xi32>
+  %2 = tt.histogram %0, %1 : tensor<512xi32> -> tensor<16xi32>
+  tt.return
+}
+
 // CHECK-LABEL: descriptor_load
 tt.func @descriptor_load(%0: !tt.tensordesc<tensor<128xf32>>) {
   // CHECK: tt.descriptor_load %{{.+}}[%{{.+}}] : !tt.tensordesc<tensor<128xf32>> -> tensor<128xf32>
```

test/TritonGPU/canonicalize.mlir

Lines changed: 9 additions & 7 deletions

```diff
@@ -81,20 +81,22 @@ tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) ->
 // -----
 
 // CHECK-LABEL: @test_canonicalize_convert_histogram
-// CHECK-SAME: (%[[ARG:.+]]: tensor<256xi32
-// CHECK-NOT: ttg.convert_layout
-// CHECK: %[[V:.+]] = tt.histogram %[[ARG]]
+// CHECK-SAME: (%[[SRC:.+]]: tensor<256xi32
+// CHECK-SAME: %[[MASK:.+]]: tensor<256xi1
+// CHECK: %[[M:.+]] = ttg.convert_layout %[[MASK]]
+// CHECK: %[[V:.+]] = tt.histogram %[[SRC]], %[[M]]
 // CHECK-NOT: ttg.convert_layout
 // CHECK: tt.return %[[V]]
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
-tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>) -> tensor<512xi32, #blocked2> {
+tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>, %arg1: tensor<256xi1, #blocked2>) -> tensor<512xi32, #blocked2> {
   %0 = ttg.convert_layout %arg0 : tensor<256xi32, #blocked1> -> tensor<256xi32, #blocked>
-  %1 = tt.histogram %0 : tensor<256xi32, #blocked> -> tensor<512xi32, #blocked>
-  %2 = ttg.convert_layout %1 : tensor<512xi32, #blocked> -> tensor<512xi32, #blocked2>
-  tt.return %2 : tensor<512xi32, #blocked2>
+  %1 = ttg.convert_layout %arg1 : tensor<256xi1, #blocked2> -> tensor<256xi1, #blocked>
+  %2 = tt.histogram %0, %1 : tensor<256xi32, #blocked> -> tensor<512xi32, #blocked>
+  %3 = ttg.convert_layout %2 : tensor<512xi32, #blocked> -> tensor<512xi32, #blocked2>
+  tt.return %3 : tensor<512xi32, #blocked2>
 }
 } // end module
 
```
