Skip to content

Commit 88a4c70

Browse files
authored
[Gluon] Import math functions from triton (#7089)
1 parent 0455876 commit 88a4c70

File tree

3 files changed

+85
-0
lines changed

3 files changed

+85
-0
lines changed

python/test/gluon/test_frontend.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,3 +635,73 @@ def test_broadcast(fresh_knobs):
635635
} loc(#loc)
636636
#loc = loc(unknown)
637637
""")
638+
639+
640+
@gluon.jit
641+
def math_kernel():
642+
layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
643+
a = ttgl.full([16, 16], 1, ttgl.float32, layout)
644+
b = ttgl.full([16, 16], 2, ttgl.float32, layout)
645+
c = ttgl.full([16, 16], 4, ttgl.float32, layout)
646+
d = ttgl.full([16, 16], 1, ttgl.int32, layout)
647+
e = ttgl.full([16, 16], 1, ttgl.int32, layout)
648+
ttgl.umulhi(d, e)
649+
ttgl.exp(a)
650+
ttgl.exp2(a)
651+
ttgl.log(a)
652+
ttgl.log2(a)
653+
ttgl.cos(a)
654+
ttgl.sin(a)
655+
ttgl.sqrt(a)
656+
ttgl.sqrt_rn(a)
657+
ttgl.rsqrt(a)
658+
ttgl.abs(a)
659+
ttgl.fdiv(a, b)
660+
ttgl.div_rn(a, b)
661+
ttgl.erf(a)
662+
ttgl.floor(a)
663+
ttgl.ceil(a)
664+
ttgl.fma(a, b, c)
665+
666+
667+
def test_math(fresh_knobs):
668+
knobs.compilation.disable_line_info = True
669+
670+
h = math_kernel.warmup(sanitize_overflow=False, grid=(1, ))
671+
expecttest.assert_expected_inline(
672+
anonymize_ir(h.asm["source"]), """\
673+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
674+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
675+
tt.func public @math_kernel() attributes {noinline = false} {
676+
%cst = arith.constant 1.000000e+00 : f32 loc(#loc)
677+
%cst_0 = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
678+
%cst_1 = arith.constant 2.000000e+00 : f32 loc(#loc)
679+
%cst_2 = arith.constant dense<2.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
680+
%cst_3 = arith.constant 4.000000e+00 : f32 loc(#loc)
681+
%cst_4 = arith.constant dense<4.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
682+
%c1_i32 = arith.constant 1 : i32 loc(#loc)
683+
%cst_5 = arith.constant dense<1> : tensor<16x16xi32, #blocked> loc(#loc)
684+
%c1_i32_6 = arith.constant 1 : i32 loc(#loc)
685+
%cst_7 = arith.constant dense<1> : tensor<16x16xi32, #blocked> loc(#loc)
686+
%0 = tt.mulhiui %cst_5, %cst_7 : tensor<16x16xi32, #blocked> loc(#loc)
687+
%1 = math.exp %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
688+
%2 = math.exp2 %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
689+
%3 = math.log %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
690+
%4 = math.log2 %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
691+
%5 = math.cos %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
692+
%6 = math.sin %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
693+
%7 = math.sqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
694+
%8 = tt.precise_sqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
695+
%9 = math.rsqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
696+
%10 = math.absf %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
697+
%11 = arith.divf %cst_0, %cst_2 : tensor<16x16xf32, #blocked> loc(#loc)
698+
%12 = tt.precise_divf %cst_0, %cst_2 : tensor<16x16xf32, #blocked> loc(#loc)
699+
%13 = math.erf %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
700+
%14 = math.floor %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
701+
%15 = math.ceil %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
702+
%16 = math.fma %cst_0, %cst_2, %cst_4 : tensor<16x16xf32, #blocked> loc(#loc)
703+
tt.return loc(#loc)
704+
} loc(#loc)
705+
} loc(#loc)
706+
#loc = loc(unknown)
707+
""")

python/triton/experimental/gluon/language/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
from ._core import __all__ as __core_all
33
from ._layouts import * # NOQA: F403
44
from ._layouts import __all__ as __layouts_all
5+
from ._math import * # NOQA: F403
6+
from ._math import __all__ as __math_all
57

68
from . import nvidia
79

810
__all__ = [
911
*__core_all,
1012
*__layouts_all,
13+
*__math_all,
1114
"nvidia",
1215
]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# flake8: noqa
2+
import triton.language.math as tl_math
3+
from ._core import builtin
4+
5+
__all__ = [
6+
"umulhi", "exp", "exp2", "fma", "log", "log2", "cos", "rsqrt", "sin", "sqrt", "sqrt_rn", "abs", "fdiv", "div_rn",
7+
"erf", "floor", "ceil"
8+
]
9+
10+
for name in __all__:
11+
fn = getattr(tl_math, name)
12+
globals()[name] = builtin(fn)

0 commit comments

Comments
 (0)