Skip to content

Commit 4390874

Browse files
authored
[Gluon] Add zeros, zeros_like and full_like (#7151)
1 parent e1fb6f6 commit 4390874

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

python/test/gluon/test_frontend.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,3 +881,30 @@ def test_tensor_reshape():
881881
expect_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1, 2], [2, 4, 4], [4, 1, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1],
882882
[2, 1, 0])
883883
ttgl.static_assert(v.type.layout == expect_layout)
884+
885+
886+
@filecheck_test
887+
@gluon.jit
888+
def test_zeros():
889+
# CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [2]
890+
# CHECK: [[BLOCKED2D:#.*]] = #ttg.blocked<{sizePerThread = [1, 2]
891+
layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
892+
layout_2d: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0])
893+
894+
# CHECK: arith.constant dense<0.000000e+00> : tensor<32xf32, [[BLOCKED]]>
895+
a = ttgl.zeros([32], ttgl.float32, layout)
896+
897+
# CHECK: arith.constant dense<7.000000e+00> : tensor<32xf32, [[BLOCKED]]>
898+
ttgl.full_like(a, 7)
899+
900+
# CHECK: arith.constant dense<0.000000e+00> : tensor<32xf32, [[BLOCKED]]>
901+
ttgl.zeros_like(a)
902+
903+
# CHECK: arith.constant dense<0.000000e+00> : tensor<64xf32, [[BLOCKED]]>
904+
ttgl.zeros_like(a, shape=[64])
905+
906+
# CHECK: arith.constant dense<0> : tensor<16x16xi8, [[BLOCKED2D]]>
907+
ttgl.zeros_like(a, shape=[16, 16], dtype=ttgl.int8, layout=layout_2d)
908+
909+
# CHECK: arith.constant dense<7> : tensor<8x8xi16, [[BLOCKED2D]]>
910+
ttgl.full_like(a, 7, shape=[8, 8], dtype=ttgl.int16, layout=layout_2d)

python/triton/experimental/gluon/language/_standard.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import triton.language.standard as tl_standard
44
from .._runtime import jit
55
from triton import knobs
6+
from . import _core as ttgl
67

78
_IMPORT_FROM_TRITON = [
89
"sum",
@@ -12,10 +13,35 @@
1213
"xor_sum",
1314
]
1415

15-
__all__ = _IMPORT_FROM_TRITON
16+
__all__ = [
17+
"full_like",
18+
"zeros",
19+
"zeros_like",
20+
*_IMPORT_FROM_TRITON,
21+
]
1622

1723
for name in _IMPORT_FROM_TRITON:
1824
# Convert JITFunction -> GluonJitFunction
1925
fn = getattr(tl_standard, name)
2026
assert knobs.runtime.interpret or isinstance(fn, triton.runtime.JITFunction)
2127
globals()[name] = jit(fn.fn)
28+
29+
30+
@jit
31+
def zeros(shape, dtype, layout):
32+
return ttgl.full(shape, 0, dtype, layout)
33+
34+
35+
@jit
36+
def full_like(input, value, shape=None, dtype=None, layout=None):
37+
return ttgl.full(
38+
input.shape if shape is None else shape,
39+
value,
40+
input.dtype if dtype is None else dtype,
41+
input.type.layout if layout is None else layout,
42+
)
43+
44+
45+
@jit
46+
def zeros_like(input, shape=None, dtype=None, layout=None):
47+
return full_like(input, 0, shape=shape, dtype=dtype, layout=layout)

0 commit comments

Comments
 (0)