Skip to content

Commit 4334cec

Browse files
manuelcandales authored and facebook-github-bot committed
Integer division binary ops: floor_divide, fmod, remainder (#6012)
Summary: Pull Request resolved: #6012 - remainder: 386 K -> 15 K - fmod: 317 K -> 14 K - floor_divide: 255 K -> 11 K ghstack-source-id: 246985124 exported-using-ghexport Reviewed By: swolchok Differential Revision: D63909727 fbshipit-source-id: 44d1fd575a3723df5982833abf5f9528b307a4bf
1 parent 9b1d333 commit 4334cec

File tree

4 files changed

+215
-347
lines changed

4 files changed

+215
-347
lines changed

kernels/portable/cpu/op_floor_divide.cpp

Lines changed: 40 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
9+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1010
#include <executorch/kernels/portable/cpu/util/math_util.h>
1111
#include <executorch/runtime/kernel/kernel_includes.h>
1212
#include <executorch/runtime/platform/assert.h>
@@ -17,106 +17,61 @@ namespace torch {
1717
namespace executor {
1818
namespace native {
1919

20-
using Tensor = exec_aten::Tensor;
21-
using ScalarType = exec_aten::ScalarType;
22-
23-
namespace {
24-
template <
25-
bool can_cast,
26-
typename CTYPE_A,
27-
typename CTYPE_B,
28-
typename CTYPE_IN,
29-
typename CTYPE_OUT>
30-
struct FloorDivideInner;
31-
32-
template <
33-
typename CTYPE_A,
34-
typename CTYPE_B,
35-
typename CTYPE_IN,
36-
typename CTYPE_OUT>
37-
struct FloorDivideInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
38-
static void
39-
run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
40-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
41-
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
42-
[&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
43-
if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
44-
if (val_b == 0) {
45-
div_by_zero_error = true;
46-
return static_cast<CTYPE_OUT>(0);
47-
}
48-
}
49-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
50-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
51-
CTYPE_IN value = utils::floor_divide<CTYPE_IN>(a_casted, b_casted);
52-
53-
return static_cast<CTYPE_OUT>(value);
54-
},
55-
a,
56-
b,
57-
out);
58-
}
59-
};
60-
61-
struct ReportCanCastBug {
62-
static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
63-
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
64-
}
65-
};
66-
67-
template <
68-
typename CTYPE_A,
69-
typename CTYPE_B,
70-
typename CTYPE_IN,
71-
typename CTYPE_OUT>
72-
struct FloorDivideInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
73-
: public ReportCanCastBug {};
74-
75-
} // namespace
76-
7720
Tensor& floor_divide_out(
7821
KernelRuntimeContext& ctx,
7922
const Tensor& a,
8023
const Tensor& b,
8124
Tensor& out) {
25+
// Common Dtype
26+
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
27+
28+
// Check Common Dtype
8229
ET_KERNEL_CHECK(
8330
ctx,
84-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
31+
(canCast(common_type, out.scalar_type()) &&
32+
common_type != ScalarType::Bool),
8533
InvalidArgument,
8634
out);
8735

88-
ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out);
89-
36+
// Check Dim Order
9037
ET_KERNEL_CHECK(
9138
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
9239

93-
ScalarType a_type = a.scalar_type();
94-
ScalarType b_type = b.scalar_type();
95-
ScalarType common_type = promoteTypes(a_type, b_type);
96-
ScalarType out_type = out.scalar_type();
40+
// Resize
41+
ET_KERNEL_CHECK(
42+
ctx,
43+
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
44+
InvalidArgument,
45+
out);
46+
47+
// Compute Dtype
48+
ScalarType compute_type = utils::get_compute_type(common_type);
9749

98-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
50+
// @lint-ignore CLANGTIDY facebook-hte-CArray
51+
static constexpr const char op_name[] = "floor_divide.out";
9952

100-
auto div_by_zero_error = false;
53+
bool div_by_zero_error = false;
10154

102-
ET_SWITCH_REAL_TYPES_AND(
103-
Bool, a_type, ctx, "floor_divide.out", CTYPE_A, [&]() {
104-
ET_SWITCH_REAL_TYPES_AND(
105-
Bool, b_type, ctx, "floor_divide.out", CTYPE_B, [&]() {
106-
using CTYPE_IN = typename torch::executor::
107-
promote_types<CTYPE_A, CTYPE_B>::type;
108-
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
109-
ET_SWITCH_REAL_TYPES(
110-
out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
111-
FloorDivideInner<
112-
can_cast<CTYPE_IN, CTYPE_OUT>::value,
113-
CTYPE_A,
114-
CTYPE_B,
115-
CTYPE_IN,
116-
CTYPE_OUT>::run(a, b, out, div_by_zero_error);
117-
});
118-
});
119-
});
55+
ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
57+
[&div_by_zero_error](
58+
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59+
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
60+
if (val_b == 0) {
61+
div_by_zero_error = true;
62+
return static_cast<CTYPE_COMPUTE>(0);
63+
}
64+
}
65+
return utils::floor_divide(val_a, val_b);
66+
},
67+
ctx,
68+
a,
69+
utils::SupportedTensorDtypes::REALHBBF16,
70+
b,
71+
utils::SupportedTensorDtypes::REALHBBF16,
72+
out,
73+
utils::SupportedTensorDtypes::REALHBF16);
74+
});
12075

12176
ET_KERNEL_CHECK_MSG(
12277
ctx,

0 commit comments

Comments (0)