Skip to content

Commit 13bae05

Browse files
authored
[Gluon] Add split, join and reshape on tensor (#7122)
1 parent e594544 commit 13bae05

File tree

3 files changed

+67
-13
lines changed

3 files changed

+67
-13
lines changed

python/test/gluon/test_frontend.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,3 +844,36 @@ def test_tensor_permute():
844844
res = ttgl.permute(a, [1, 0])
845845
permuted_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 1], [8, 4], [1, 4], [0, 1], [1, 1], [1, 1], [1, 0])
846846
ttgl.static_assert(permuted_layout == res.type.layout)
847+
848+
849+
@filecheck_test
@gluon.jit
def test_split_join():
    # Verify that ttgl.join/ttgl.split round-trip a pair of 1-D tensors and
    # that the inferred layouts match the expected blocked layouts.
    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
    # CHECK: [[BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
    layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
    # Two tensors sharing the same 1-D blocked layout.
    a = ttgl.full([128], 1, ttgl.int32, layout)
    b = ttgl.full([128], 2, ttgl.int32, layout)
    # Joining appends a new minor dimension of size 2 (128 -> 128x2), and the
    # result layout is inferred rather than supplied by the caller.
    # CHECK: tt.join {{.*}} : tensor<128xi32, [[BLOCKED]]> -> tensor<128x2xi32, [[BLOCKED1]]>
    res = ttgl.join(a, b)
    expect_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 2], [32, 1], [4, 1], [1, 0], [1, 1], [1, 1], [1, 0])
    ttgl.static_assert(res.type.layout == expect_layout)

    # Splitting the joined tensor recovers two tensors with the original layout.
    # CHECK: tt.split {{.*}} : tensor<128x2xi32, [[BLOCKED1]]> -> tensor<128xi32, [[BLOCKED]]>
    c, d = ttgl.split(res)
    ttgl.static_assert(c.type.layout == layout)
    ttgl.static_assert(d.type.layout == layout)
866+
867+
868+
@filecheck_test
@gluon.jit
def test_tensor_reshape():
    # Verify that tensor.reshape emits tt.reshape and that the 3-D result
    # layout inferred for the reshaped tensor matches the expected one.
    # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
    # CHECK: [[BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [2, 4, 4], warpsPerCTA = [4, 1, 1], order = [2, 1, 0]}>
    layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
    a = ttgl.full([256], 1, ttgl.int32, layout)
    # 256 elements reshaped to 8x4x8; element count is preserved.
    # CHECK: tt.reshape {{.*}} : tensor<256xi32, [[BLOCKED]]> -> tensor<8x4x8xi32, [[BLOCKED1]]>
    v = a.reshape([8, 4, 8])
    expect_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1, 2], [2, 4, 4], [4, 1, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1],
                                                       [2, 1, 0])
    ttgl.static_assert(v.type.layout == expect_layout)

python/triton/experimental/gluon/language/_core.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,20 @@
4242
)
4343

4444
# Triton language symbols re-exported under the gluon namespace.
# NOTE(review): the forwarding mechanism lives outside this hunk — presumably
# these names are injected into the module dynamically; verify before relying
# on static analysis of them. Keep the list alphabetically sorted.
_IMPORT_FROM_TRITON: List[str] = [
    "expand_dims",
    "join",
    "load",
    "maximum",
    "minimum",
    "permute",
    "program_id",
    "reduce",
    "reshape",
    "split",
    "static_assert",
    "store",
    "to_tensor",
    "where",
]
5760

5861
__all__ = [

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ class GluonSemantic(TritonSemantic[TensorTy]):
2222
def __init__(self, builder: GluonOpBuilder):
2323
self.builder = builder
2424

25+
def _wrap_tensor_infer_layout(self, tensor):
26+
ty = ttgl.distributed_type(tensor.type.scalar, tensor.shape,
27+
self.builder.get_gluon_layout_from_tensor(tensor.handle))
28+
return self.tensor(tensor.handle, ty)
29+
2530
def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]):
2631
if len(lhs_shape) != len(rhs_shape):
2732
raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}")
@@ -57,11 +62,19 @@ def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
5762
handle = self.builder.create_expand_dims(input.handle, axis, ret_ty.to_ir(self.builder))
5863
return self.tensor(handle, ret_ty)
5964

65+
def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
66+
a, b = self.broadcast_impl_value(a, b)
67+
_check(a.shape != [], "Cannot join scalars in gluon")
68+
value = super().join(a, b)
69+
return self._wrap_tensor_infer_layout(value)
70+
71+
def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
72+
lhs, rhs = super().split(a)
73+
return self._wrap_tensor_infer_layout(lhs), self._wrap_tensor_infer_layout(rhs)
74+
6075
def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
6176
value = super().permute(input, dims)
62-
layout = self.builder.get_gluon_layout_from_tensor(value.handle)
63-
res_ty = ttgl.distributed_type(value.type.scalar, value.shape, layout)
64-
return self.tensor(value.handle, res_ty)
77+
return self._wrap_tensor_infer_layout(value)
6578

6679
def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
6780
_check(isinstance(input.type, ttgl.distributed_type),
@@ -106,6 +119,11 @@ def arange(self, start, end, layout):
106119
ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
107120
return super().arange(start, end, ret_ty=ret_ty)
108121

122+
def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool):
123+
_check(not can_reorder, "can_reorder is not supported in gluon")
124+
value = super().reshape(input, dst_shape, can_reorder)
125+
return self._wrap_tensor_infer_layout(value)
126+
109127
def splat(self, value, shape, layout):
110128
ret_ty = ttgl.distributed_type(value.dtype, shape, layout)
111129
handle = self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle)

0 commit comments

Comments
 (0)