Commit d57cbee
[Gluon] Implement reductions (#7091)
1 parent 750cc53 commit d57cbee
8 files changed, +147 -6 lines changed


python/test/gluon/test_frontend.py

Lines changed: 76 additions & 0 deletions

@@ -1,4 +1,5 @@
 import expecttest
+from triton.runtime.jit import MockTensor
 import torch
 import pytest
 import re

@@ -705,3 +706,78 @@ def test_math(fresh_knobs):
 } loc(#loc)
 #loc = loc(unknown)
 """)
+
+
+@gluon.jit
+def pair_add(a0, a1, b0, b1):
+    return a0 + b0, a1 + b1
+
+
+@gluon.jit
+def reduce_kernel(out):
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
+    a = ttgl.full([16, 16], 1, ttgl.float32, layout)
+    b = ttgl.full([16, 16], 2, ttgl.float32, layout)
+    s0 = ttgl.sum(a, 0)
+    ttgl.static_assert(s0.type.layout == ttgl.SliceLayout(0, layout))
+    s1 = ttgl.sum(a, 1)
+    ttgl.static_assert(s1.type.layout == ttgl.SliceLayout(1, layout))
+
+    scalar = ttgl.max(s0, 0)
+    ttgl.static_assert(scalar.type == ttgl.float32)
+
+    s1 = ttgl.convert_layout(s1, s0.type.layout)
+
+    pairs = ttgl.reduce((a, b), 0, pair_add)
+    ttgl.static_assert(pairs[0].type.layout == ttgl.SliceLayout(0, layout))
+    ttgl.static_assert(pairs[1].type.layout == ttgl.SliceLayout(0, layout))
+    result = scalar + s1 + pairs[0] + pairs[1]
+    tl.store(out + ttgl.arange(0, 16, s0.type.layout), result)
+
+
+def test_reduce(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = reduce_kernel.warmup(MockTensor(ttgl.float32), sanitize_overflow=False, grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["ttgir"]), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#loc = loc(unknown)
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @reduce_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} {
+    %cst = arith.constant dense<2.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %0 = "tt.reduce"(%cst_0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
+      %12 = arith.addf %arg1, %arg2 : f32 loc(#loc)
+      tt.reduce.return %12 : f32 loc(#loc)
+    }) : (tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %1 = "tt.reduce"(%cst_0) <{axis = 1 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
+      %12 = arith.addf %arg1, %arg2 : f32 loc(#loc)
+      tt.reduce.return %12 : f32 loc(#loc)
+    }) : (tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc)
+    %2 = "tt.reduce"(%0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
+      %12 = arith.maxnumf %arg1, %arg2 : f32 loc(#loc)
+      tt.reduce.return %12 : f32 loc(#loc)
+    }) : (tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>) -> f32 loc(#loc)
+    %3 = ttg.convert_layout %1 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %4:2 = "tt.reduce"(%cst_0, %cst) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown), %arg3: f32 loc(unknown), %arg4: f32 loc(unknown)):
+      %12 = arith.addf %arg1, %arg3 : f32 loc(#loc)
+      %13 = arith.addf %arg2, %arg4 : f32 loc(#loc)
+      tt.reduce.return %12, %13 : f32, f32 loc(#loc)
+    }) : (tensor<16x16xf32, #blocked>, tensor<16x16xf32, #blocked>) -> (tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>) loc(#loc)
+    %5 = tt.splat %2 : f32 -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %6 = arith.addf %5, %3 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %7 = arith.addf %6, %4#0 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %8 = arith.addf %7, %4#1 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %10 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %11 = tt.addptr %10, %9 : tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    tt.store %11, %8 : tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+""")

python/triton/compiler/code_generator.py

Lines changed: 2 additions & 2 deletions

@@ -1464,7 +1464,7 @@ def ret(self, node: ast.Call):
     }


-def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
+def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
     arg_types = [None] * len(fn.arg_names)
     for k, v in src.signature.items():
         idx = fn.arg_names.index(k)

@@ -1479,7 +1479,7 @@ def ast_to_ttir(fn, src, context, options, codegen_fns, module_map):
     proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
     generator = CodeGenerator(context, prototype, gscope=fn.__globals__.copy(), function_name=fn.repr(proxy), jit_fn=fn,
                               is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
-                              codegen_fns=codegen_fns, module_map=module_map)
+                              codegen_fns=codegen_fns, module_map=module_map, module=module)
     generator.visit(fn.parse())
     ret = generator.module
     # module takes ownership of the context
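
The signature change lets a caller pass a pre-built module into ast_to_ttir instead of having the generator create one; CodeGenerator then populates that module. A minimal sketch of the pattern, with made-up names rather than Triton's API:

# Minimal sketch of the optional pre-built-module pattern (illustrative names only).
class Module:
    def __init__(self):
        self.attrs = {}

def lower(fn_name, module=None):
    if module is None:
        module = Module()  # default path: the lowering creates and owns a fresh module
    module.attrs.setdefault("functions", []).append(fn_name)
    return module

# Caller-supplied module: attributes set before lowering stay visible during it.
m = Module()
m.attrs["ttg.num-warps"] = 4
assert lower("reduce_kernel", module=m) is m and m.attrs["ttg.num-warps"] == 4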

python/triton/experimental/gluon/_runtime.py

Lines changed: 7 additions & 2 deletions

@@ -18,9 +18,11 @@ def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:

     def make_ir(self, options, codegen_fns, module_map, context):
         from triton.compiler.compiler import make_backend
-        module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
-                             module_map=module_map)
+
         builder = ir.builder(context)
+        module = builder.create_module()
+
+        # Assign module attributes eagerly, as they are needed to verify layouts
         target = triton.runtime.driver.active.get_current_target()
         backend = make_backend(target)
         target = backend.get_target_name(options)

@@ -30,6 +32,9 @@ def make_ir(self, options, codegen_fns, module_map, context):
         module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(32))
         if options.maxnreg is not None:
             module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
+
+        module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
+                             module_map=module_map, module=module)
         return module

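The comment explains the reordering: the ttg.* attributes must already be on the module when ast_to_ttir runs, because Gluon layouts are verified during code generation. A toy illustration of the kind of consistency check this enables (not Triton's actual verifier), using the values from the test:

# Toy stand-in for a layout/module consistency check (not Triton's real verifier).
def check_blocked_layout(threads_per_warp, warps_per_cta, module_attrs):
    threads = 1
    for t in threads_per_warp:
        threads *= t
    warps = 1
    for w in warps_per_cta:
        warps *= w
    assert threads == module_attrs["ttg.threads-per-warp"], "threads-per-warp mismatch"
    assert warps == module_attrs["ttg.num-warps"], "num-warps mismatch"

# Matches BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0]) in reduce_kernel and the
# attributes make_ir sets on the module before code generation runs.
check_blocked_layout([1, 32], [4, 1], {"ttg.num-warps": 4, "ttg.threads-per-warp": 32})
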
python/triton/experimental/gluon/language/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -4,12 +4,15 @@
 from ._layouts import __all__ as __layouts_all
 from ._math import *  # NOQA: F403
 from ._math import __all__ as __math_all
+from ._standard import *  # NOQA: F403
+from ._standard import __all__ as __standard_all

 from . import nvidia

 __all__ = [
     *__core_all,
     *__layouts_all,
     *__math_all,
+    *__standard_all,
     "nvidia",
 ]

python/triton/experimental/gluon/language/_core.py

Lines changed: 3 additions & 1 deletion

@@ -43,8 +43,10 @@

 _IMPORT_FROM_TRITON: List[str] = [
     "expand_dims",  # NOQA: F822
-    "program_id",  # NOQA: F822
     "load",  # NOQA: F822
+    "program_id",  # NOQA: F822
+    "reduce",  # NOQA: F822
+    "static_assert",  # NOQA: F822
     "store",  # NOQA: F822
     "to_tensor",  # NOQA: F822
 ]

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 35 additions & 0 deletions

@@ -180,6 +180,41 @@ def memdesc_reinterpret(self, mem_desc, dtype, shape, layout):
         handle = self.builder.create_memdesc_reinterpret(ty.to_ir(self.builder), mem_desc.handle)
         return ttgl.shared_memory_descriptor(handle, **ty.__dict__)

+    def wrap_tensor(self, x, scalar_ty, ret_shape, layout):
+        if ret_shape:
+            res_ty = ttgl.distributed_type(scalar_ty, ret_shape, layout)
+        else:
+            res_ty = scalar_ty
+        return self.tensor(x, res_ty)
+
+    @staticmethod
+    def _check_same_layout(xs):
+        for x in xs:
+            _check(isinstance(x.type, ttgl.distributed_type), lambda: f"expected distributed_type but got: {x.type!r}")
+        layouts = [x.type.layout for x in xs]
+        l0 = layouts[0]
+        _check(all(l == l0 for l in layouts[1:]),
+               lambda: f"Expected inputs to have matching layouts, but got: {layouts}")
+
+    def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
+        _check(axis is not None, lambda: "All-reduce is not yet implemented in gluon")
+        # get result shape
+        shape = inputs[0].type.shape
+        rank = len(shape)
+        _check(0 <= axis < rank, lambda: f"expected reduction axis to be in the range [0, {rank}) but got {axis}")
+        self._check_same_layout(inputs)
+        ret_shape = [s for i, s in enumerate(shape) if i != axis]
+        ret_layout = SliceLayout(axis, inputs[0].type.layout)
+        assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
+
+        reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
+        region_builder_fn(reduce_op)
+        assert reduce_op.verify()
+
+        return tuple(
+            self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape, ret_layout)
+            for i in range(len(inputs)))
+
     def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
                         worker_num_regs: Sequence[int], generator):
         num_partitions = len(worker_partitions)
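
A small stand-alone check (hypothetical, mirroring the reduce_kernel test) of the shape bookkeeping in reduction(): reducing a [16, 16] input over axis 0 drops that axis, and reducing the 1-D result again leaves an empty ret_shape, which wrap_tensor turns into a plain scalar type.

# Hypothetical sanity check of the ret_shape computation used in reduction() above.
shape = [16, 16]
axis = 0
ret_shape = [s for i, s in enumerate(shape) if i != axis]
assert ret_shape == [16]  # the result carries SliceLayout(0, parent_layout)

# Reducing the 1-D result leaves an empty shape, so wrap_tensor returns a scalar,
# matching ttgl.static_assert(scalar.type == ttgl.float32) in reduce_kernel.
assert [s for i, s in enumerate(ret_shape) if i != 0] == []
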
python/triton/experimental/gluon/language/_standard.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+# flake8: noqa
+import triton
+import triton.language.standard as tl_standard
+from .._runtime import jit
+
+_IMPORT_FROM_TRITON = [
+    "sum",
+    "max",
+    "min",
+    "reduce_or",
+    "xor_sum",
+]
+
+__all__ = _IMPORT_FROM_TRITON
+
+for name in _IMPORT_FROM_TRITON:
+    # Convert JITFunction -> GluonJitFunction
+    fn = getattr(tl_standard, name)
+    assert isinstance(fn, triton.runtime.JITFunction)
+    globals()[name] = jit(fn.fn)
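
The loop re-wraps Triton's standard reductions (sum, max, min, reduce_or, xor_sum) under the Gluon jit decorator so they can be called from Gluon kernels as ttgl.sum, ttgl.max, and so on (the test above uses ttgl.sum and ttgl.max). A self-contained sketch of that unwrap-and-rewrap pattern, using stand-in decorators rather than Triton's real ones:

# Stand-in decorators illustrating the unwrap/re-wrap pattern above (not Triton's API).
class TritonJit:
    def __init__(self, fn):
        self.fn = fn  # like JITFunction, keep the raw Python function on .fn

class GluonJit(TritonJit):
    pass

def triton_jit(fn):
    return TritonJit(fn)

def gluon_jit(fn):
    return GluonJit(fn)

@triton_jit
def my_sum(x, axis):
    ...  # placeholder body

# Mirrors `globals()[name] = jit(fn.fn)` in _standard.py.
gluon_sum = gluon_jit(my_sum.fn)
assert isinstance(gluon_sum, GluonJit) and gluon_sum.fn is my_sum.fn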

python/triton/language/semantic.py

Lines changed: 1 addition & 1 deletion

@@ -1659,7 +1659,7 @@ def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) ->

         reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
         region_builder_fn(reduce_op)
-        reduce_op.verify()
+        assert reduce_op.verify()

         return tuple(
             self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs)))
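
The one-line change checks verify()'s return value instead of discarding it, so an invalid reduce op now fails immediately rather than slipping through. A toy illustration with a stand-in op class (not Triton's ir bindings):

# Stand-in op showing the difference between ignoring and asserting on verify().
class FakeReduceOp:
    def verify(self):
        return False  # pretend the op fails verification

op = FakeReduceOp()
op.verify()                 # old behaviour: the failure is ignored
try:
    assert op.verify()      # new behaviour: the failure surfaces right away
except AssertionError:
    print("invalid reduce op caught")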

0 commit comments
