Skip to content

Commit 243be25

Browse files
authored
[FRONTEND] Fix mismatched type for extern elementwise (#7930)
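PR #7890 forces the return type of a non-scalar `core.extern_elementwise` call to be the same as the type of the broadcast arguments, which is not correct in all cases (for example, when the extern function returns a different element type than its inputs, such as an `f32` input producing an `i32` result). This PR fixes that by keeping the broadcast arguments' tensor type but replacing its element type with the specified return element type. Fixes #7921.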
1 parent fff5a2d commit 243be25

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

python/test/gluon/test_frontend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -940,10 +940,14 @@ def libdevice_kernel():
940940
a = ttgl.full([4, 32], 1, ttgl.float32, layout)
941941
b = ttgl.full([4, 32], 2, ttgl.float32, layout)
942942
c = ttgl.full([4, 32], 4, ttgl.float32, layout)
943+
943944
libdevice.abs(a)
944945
libdevice.fast_dividef(a, b)
945946
libdevice.fma(a, b, c)
946947

948+
libdevice.isnan(a)
949+
libdevice.isinf(a)
950+
947951

948952
@pytest.mark.parametrize("target", ALL_TARGETS)
949953
def test_libdevice(target):
@@ -962,6 +966,14 @@ def test_libdevice(target):
962966
%0 = tt.extern_elementwise %cst_0 {libname = "", libpath = "", pure = true, symbol = "..."} : (tensor<4x32xf32, #blocked>) -> tensor<4x32xf32, #blocked>
963967
%1 = tt.extern_elementwise %cst_0, %cst_2 {libname = "", libpath = "", pure = true, symbol = "..."} : (tensor<4x32xf32, #blocked>, tensor<4x32xf32, #blocked>) -> tensor<4x32xf32, #blocked>
964968
%2 = tt.extern_elementwise %cst_0, %cst_2, %cst_4 {libname = "", libpath = "", pure = true, symbol = "..."} : (tensor<4x32xf32, #blocked>, tensor<4x32xf32, #blocked>, tensor<4x32xf32, #blocked>) -> tensor<4x32xf32, #blocked>
969+
%3 = tt.extern_elementwise %cst_0 {libname = "", libpath = "", pure = true, symbol = "..."} : (tensor<4x32xf32, #blocked>) -> tensor<4x32xi32, #blocked>
970+
%c0_i32 = arith.constant 0 : i32
971+
%cst_5 = arith.constant dense<0> : tensor<4x32xi32, #blocked>
972+
%4 = arith.cmpi ne, %3, %cst_5 : tensor<4x32xi32, #blocked>
973+
%5 = tt.extern_elementwise %cst_0 {libname = "", libpath = "", pure = true, symbol = "..."} : (tensor<4x32xf32, #blocked>) -> tensor<4x32xi32, #blocked>
974+
%c0_i32_6 = arith.constant 0 : i32
975+
%cst_7 = arith.constant dense<0> : tensor<4x32xi32, #blocked>
976+
%6 = arith.cmpi ne, %5, %cst_7 : tensor<4x32xi32, #blocked>
965977
tt.return
966978
}
967979
}

python/triton/language/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3360,7 +3360,7 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
33603360
dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg,
33613361
arithmetic_check=arithmetic_check)
33623362
if not all_scalar:
3363-
ret_type = broadcast_arg.type
3363+
ret_type = broadcast_arg.type.with_element_ty(ret_type)
33643364
func = _semantic.builder.create_extern_elementwise
33653365
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_type, is_pure, _semantic)
33663366

0 commit comments

Comments
 (0)