Skip to content

Commit a0fe0e3

Browse files
agron911 and meta-codesync[bot]
authored and committed
[Cherry-pick][RESOLVED] [GLUON] Add histogram support (#7989) (#540)
Summary: ⚠️ **MERGE CONFLICTS DETECTED** ⚠️ This cherry-pick contains merge conflicts that require manual resolution. Original Commit: bfffc33 Original Author: Thomas Raoux Original Date: 2025-08-27 14:14:41 -0700 **Action Required:** 1. Check out this branch locally 2. Resolve the merge conflicts in the affected files 3. Commit the resolved changes 4. Update this PR Original commit message: ``` [GLUON] Add histogram support (#7989) Expose histogram op in gluon. This will be convenient to test more complex layouts. ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. The conflicts have been committed with conflict markers for easier resolution. Pull Request resolved: #540 Reviewed By: dshi7 Differential Revision: D85968323 Pulled By: agron911 fbshipit-source-id: f4e30114ab3a83973d5dab42d19633cddb19d49f
1 parent e8b4bf9 commit a0fe0e3

File tree

5 files changed

+60
-0
lines changed

5 files changed

+60
-0
lines changed

python/src/gluon_ir.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,20 @@ void init_gluon_ir(py::module &&m) {
436436
auto dstTy = cast<RankedTensorType>(resultTy);
437437
return isConvertLayoutTrivial(dstTy, value);
438438
})
439+
// Bind histogram-op creation for gluon: the result is a 1-D i32 tensor
// with `numBins` elements carrying the caller-supplied layout; the mask
// operand is attached only when one was provided.
.def("create_histogram",
     [](GluonOpBuilder &self, Value operand, int numBins,
        std::optional<Value> mask, Attribute layout) -> Value {
       auto *ctx = self.getContext();
       auto resultTy = RankedTensorType::get(
           {static_cast<int64_t>(numBins)}, IntegerType::get(ctx, 32),
           layout);
       if (mask)
         return self.create<triton::HistogramOp>(resultTy, operand, *mask);
       return self.create<triton::HistogramOp>(resultTy, operand);
     })
439453
.def("create_async_copy_global_to_local",
440454
[](GluonOpBuilder &self, Value smem, Value pointer, Value mask,
441455
Value other, tt::CacheModifier cacheModifier,

python/test/gluon/test_frontend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ def test_convert_layout(target):
7575
""")
7676

7777

78+
@filecheck_test
@gluon.jit
def test_histogram_frontend():
    # Lower a masked gluon histogram and check the emitted tt.histogram op.
    # CHECK: #blocked = #ttg.blocked
    # CHECK-LABEL: test_histogram_frontend
    blocked: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0])
    values = ttgl.arange(0, 256, layout=blocked)
    keep = values < 128
    # CHECK: tt.histogram %{{.*}}, %{{.*}} : tensor<256xi32, #blocked> -> tensor<512xi32, #blocked>
    _ = ttgl.histogram(values, 512, mask=keep, layout=blocked)
88+
89+
7890
@filecheck_test
7991
@gluon.jit
8092
def test_convert_layout_assert_trivial():

python/triton/experimental/gluon/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
device_assert,
4747
expand_dims,
4848
full,
49+
histogram,
4950
inline_asm_elementwise,
5051
join,
5152
load,

python/triton/experimental/gluon/language/_core.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,27 @@ def full(shape, value, dtype, layout=None, _semantic=None):
396396
return _semantic.full(shape, value, dtype, layout)
397397

398398

399+
@builtin
def histogram(input, num_bins, mask=None, layout=None, _semantic=None, _generator=None):
    """
    Build a histogram of a 1D integer tensor.

    Args:
        input (tensor): 1D tensor of integer values.
        num_bins (int): Number of bins. Bins have width 1 and start at 0.
        mask (Optional[tensor]): Boolean mask to exclude elements when False.
        layout (DistributedLayout): Destination layout of the output histogram.

    Returns:
        tensor: 1D int32 tensor of length `num_bins` with the requested layout.
    """
    # Arguments may arrive as constexpr wrappers from the frontend; unwrap
    # them before dispatching to the semantic layer.
    num_bins = _unwrap_if_constexpr(num_bins)
    layout = _unwrap_if_constexpr(layout)
    mask = None if mask is None else _semantic.to_tensor(mask)
    return _semantic.histogram(input, num_bins, mask, layout)
418+
419+
399420
@builtin
400421
def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None) -> shared_memory_descriptor:
401422
"""

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,18 @@ def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) ->
346346
self._wrap_handle_infer_layout(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape)
347347
for i in range(len(inputs)))
348348

349+
def histogram(self, input: TensorTy, num_bins: int, mask: TensorTy, layout) -> TensorTy:
    """Create a histogram op; returns a 1D int32 tensor of `num_bins` bins."""
    # The op is only defined for 1-D integer inputs and requires an explicit
    # destination layout for the result tensor.
    _check(len(input.shape) == 1, lambda: "histogram only supports 1D input")
    _check(input.dtype.is_int(), lambda: "histogram only supports integer input")
    _check(layout is not None, lambda: "histogram requires a destination layout")
    mask_handle = None
    if mask is not None:
        # Align the mask's shape with the input, then require boolean elements.
        mask, input = self.broadcast_impl_value(mask, input)
        _check(mask.type.scalar.is_bool(), lambda: "Mask must have boolean scalar type")
        mask_handle = mask.handle
    result = self.builder.create_histogram(input.handle, num_bins, mask_handle,
                                           layout._to_ir(self.builder))
    return self.wrap_tensor(result, ttgl.int32, [num_bins], layout)
360+
349361
def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
350362
worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
351363
num_partitions = len(worker_partitions)

0 commit comments

Comments
 (0)