Skip to content

Commit 3b00caa

Browse files
authored
fix error messages in gluon gather (triton-lang#8108)
1 parent d6a7139 commit 3b00caa

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
8686

8787
def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
8888
a, b = self.broadcast_impl_value(a, b)
89-
_check(a.shape != [], "Cannot join scalars in gluon")
89+
_check(a.shape != [], lambda: "Cannot join scalars in gluon")
9090
value = super().join(a, b)
9191
return self._wrap_tensor_infer_layout(value)
9292

@@ -151,7 +151,7 @@ def arange(self, start, end, layout):
151151
return super().arange(start, end, ret_ty=ret_ty)
152152

153153
def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool):
154-
_check(not can_reorder, "can_reorder is not supported in gluon")
154+
_check(not can_reorder, lambda: "can_reorder is not supported in gluon")
155155
value = super().reshape(input, dst_shape, can_reorder)
156156
return self._wrap_tensor_infer_layout(value)
157157

@@ -365,7 +365,7 @@ def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy:
365365
_check(index.type.scalar.is_int(), lambda: f"expected integer scalar type but got: {index.type.scalar!r}")
366366

367367
rank = len(src.type.shape)
368-
_check(len(index.type.shape) == rank, "source and index tensors must have the same rank")
368+
_check(len(index.type.shape) == rank, lambda: "source and index tensors must have the same rank")
369369
_check(-rank <= axis < rank, lambda: f"gather axis {axis} must be < source rank ({rank})")
370370
if axis < 0:
371371
axis += rank

0 commit comments

Comments
 (0)