Add INT32 / INT64 support to the BANG (Cambricon MLU) elementwise Add operator.

Summary of the change:
  * Descriptor::create    - accept INFINI_DTYPE_I32 and INFINI_DTYPE_I64 in CHECK_DTYPE.
  * Descriptor::calculate - dispatch the two new dtypes to the elementwise launcher.
  * AddOp::operator()     - route int32_t / int64_t through __bang_add.
  * Kernel instantiations - add int32_t and int64_t.

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed and whose angle-bracket contents had been stripped. Indentation, blank
context lines (derived from the hunk line counts) and the template arguments
(`template <typename T>`, `std::is_same_v<T, ...>`, `calculate<AddOp, ...>`) are
reconstructed - confirm them against the repository before applying.

NOTE(review): confirm that __bang_add provides int32_t and int64_t overloads on
every targeted MLU architecture. The pre-existing `else` branch (`out = a + b;`)
adds two pointers and would not compile if it were ever instantiated, so every
dtype instantiated below must be covered by the `if constexpr` condition.

diff --git a/src/infiniop/ops/add/bang/add_bang.mlu b/src/infiniop/ops/add/bang/add_bang.mlu
index 120af568b..4925a604c 100644
--- a/src/infiniop/ops/add/bang/add_bang.mlu
+++ b/src/infiniop/ops/add/bang/add_bang.mlu
@@ -31,7 +31,7 @@ infiniStatus_t Descriptor::create(
     const auto &a_shape = a_desc->shape();
     const auto &b_shape = b_desc->shape();
 
-    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
+    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_I32, INFINI_DTYPE_I64);
 
     CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
 
@@ -59,6 +59,10 @@ infiniStatus_t Descriptor::calculate(
         return _device_info->calculate<AddOp, bfloat16_t>(_info, workspace, output, inputs, queue);
     case INFINI_DTYPE_F32:
         return _device_info->calculate<AddOp, float>(_info, workspace, output, inputs, queue);
+    case INFINI_DTYPE_I32:
+        return _device_info->calculate<AddOp, int32_t>(_info, workspace, output, inputs, queue);
+    case INFINI_DTYPE_I64:
+        return _device_info->calculate<AddOp, int64_t>(_info, workspace, output, inputs, queue);
     default:
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
diff --git a/src/infiniop/ops/add/bang/add_bang_internal.mlu b/src/infiniop/ops/add/bang/add_bang_internal.mlu
index 475c13e64..a6d2d24cb 100644
--- a/src/infiniop/ops/add/bang/add_bang_internal.mlu
+++ b/src/infiniop/ops/add/bang/add_bang_internal.mlu
@@ -8,7 +8,7 @@ public:
     static constexpr size_t num_inputs = 2;
     template <typename T>
     __mlu_device__ void operator()(T *out, const T *a, const T *b, size_t num_elements) const {
-        if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float>) {
+        if constexpr (std::is_same_v<T, half> || std::is_same_v<T, bfloat16_t> || std::is_same_v<T, float> || std::is_same_v<T, int32_t> || std::is_same_v<T, int64_t>) {
             __bang_add(out, a, b, num_elements);
         } else {
             out = a + b;
@@ -21,5 +21,7 @@ LAUNCH_ELEMENTWISE_KERNEL_IMPL(Add, AddOp)
 LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, half)
 LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, bfloat16_t)
 LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, float)
+LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, int32_t)
+LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Add, int64_t)
 
 #endif // __ADD_BANG_INTERNAL_H__