Skip to content

Commit a7a89c7

Browse files
[FRONTEND] Refactor unsplat to use new op (#7586)
Co-authored-by: peterbell10 <[email protected]>
1 parent 6415039 commit a7a89c7

File tree

5 files changed

+13
-10
lines changed

5 files changed

+13
-10
lines changed

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
529529
GenericOpPattern<triton::IntToPtrOp>,
530530
GenericOpPattern<triton::PtrToIntOp>,
531531
GenericOpPattern<triton::SplatOp>,
532+
GenericOpPattern<triton::UnsplatOp>,
532533
GenericOpPattern<triton::AddPtrOp>,
533534
TritonBroadcastPattern,
534535
TritonCatPattern,

python/src/ir.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,10 @@ void init_triton_ir(py::module &&m) {
14951495
[](TritonOpBuilder &self, Type &retTy, Value &arg) -> Value {
14961496
return self.createOrFold<SplatOp>(retTy, arg);
14971497
})
1498+
.def("create_unsplat",
1499+
[](TritonOpBuilder &self, Value &arg) -> Value {
1500+
return self.createOrFold<UnsplatOp>(arg);
1501+
})
14981502
// // atomic
14991503
.def("create_atomic_cas",
15001504
[](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val,

python/triton/language/core.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from dataclasses import dataclass
1111
import builtins
1212
from .. import knobs
13-
from ..runtime.jit import jit, JITFunction
13+
from ..runtime.jit import JITFunction
1414
import inspect
1515

1616
from .._C.libtriton import ir
@@ -1831,11 +1831,6 @@ def join(a, b, _semantic=None):
18311831
return _semantic.join(a, b)
18321832

18331833

1834-
@jit
1835-
def _take_first(a, b):
1836-
return a
1837-
1838-
18391834
def _unsplat(x, _semantic=None, _generator=None):
18401835
"""
18411836
Convert a single-element tensor to a scalar.
@@ -1846,10 +1841,7 @@ def _unsplat(x, _semantic=None, _generator=None):
18461841
for d in x.shape:
18471842
numel *= d
18481843
assert numel == 1, "can only unsplat single-element tensors"
1849-
if len(x.shape) >= 2:
1850-
x = _semantic.reshape(x, [1])
1851-
x = typing.cast(tensor, reduce(x, 0, _take_first, _semantic=_semantic, _generator=_generator))
1852-
return x
1844+
return _semantic.unsplat(x)
18531845

18541846

18551847
@_tensor_member_fn

python/triton/language/semantic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,9 @@ def splat(self, value: TensorTy, shape: List[int]) -> TensorTy:
619619
ret_ty = tl.block_type(value.dtype, shape)
620620
return self.tensor(self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle), ret_ty)
621621

622+
def unsplat(self, value: TensorTy) -> TensorTy:
623+
return self.tensor(self.builder.create_unsplat(value.handle), value.dtype)
624+
622625
def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool) -> TensorTy:
623626
numel = 1
624627
for s in dst_shape:

python/triton/runtime/interpreter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,9 @@ def create_splat(self, ret_ty, arg):
663663
else: # scalar
664664
return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
665665

666+
def create_unsplat(self, arg):
667+
return TensorHandle(np.full((1, ), arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
668+
666669
def create_atomic_cas(self, ptr, cmp, val, sem, scope):
667670
if sem not in self.ir_sem_to_interpreter_sem:
668671
raise ValueError(f"unsupported semantic {sem}")

0 commit comments

Comments
 (0)