66 * LICENSE file in the root directory of this source tree.
77 */
88
9- #include < executorch/kernels/portable/cpu/util/broadcast_util .h>
9+ #include < executorch/kernels/portable/cpu/util/elementwise_util .h>
1010#include < executorch/kernels/portable/cpu/util/math_util.h>
1111#include < executorch/runtime/kernel/kernel_includes.h>
1212#include < executorch/runtime/platform/assert.h>
@@ -17,106 +17,61 @@ namespace torch {
1717namespace executor {
1818namespace native {
1919
20- using Tensor = exec_aten::Tensor;
21- using ScalarType = exec_aten::ScalarType;
22-
23- namespace {
24- template <
25- bool can_cast,
26- typename CTYPE_A,
27- typename CTYPE_B,
28- typename CTYPE_IN,
29- typename CTYPE_OUT>
30- struct FloorDivideInner ;
31-
32- template <
33- typename CTYPE_A,
34- typename CTYPE_B,
35- typename CTYPE_IN,
36- typename CTYPE_OUT>
37- struct FloorDivideInner <true , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
38- static void
39- run (const Tensor& a, const Tensor& b, Tensor& out, bool & div_by_zero_error) {
40- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
41- // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
42- [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
43- if (is_integral_type<CTYPE_IN, /* includeBool=*/ true >::value) {
44- if (val_b == 0 ) {
45- div_by_zero_error = true ;
46- return static_cast <CTYPE_OUT>(0 );
47- }
48- }
49- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
50- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
51- CTYPE_IN value = utils::floor_divide<CTYPE_IN>(a_casted, b_casted);
52-
53- return static_cast <CTYPE_OUT>(value);
54- },
55- a,
56- b,
57- out);
58- }
59- };
60-
61- struct ReportCanCastBug {
62- static void run (const Tensor&, const Tensor&, Tensor&, bool &) {
63- ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
64- }
65- };
66-
67- template <
68- typename CTYPE_A,
69- typename CTYPE_B,
70- typename CTYPE_IN,
71- typename CTYPE_OUT>
72- struct FloorDivideInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
73- : public ReportCanCastBug {};
74-
75- } // namespace
76-
7720Tensor& floor_divide_out (
7821 KernelRuntimeContext& ctx,
7922 const Tensor& a,
8023 const Tensor& b,
8124 Tensor& out) {
25+ // Common Dtype
26+ ScalarType common_type = promoteTypes (a.scalar_type (), b.scalar_type ());
27+
28+ // Check Common Dtype
8229 ET_KERNEL_CHECK (
8330 ctx,
84- resize_to_broadcast_target_size (a, b, out) == Error::Ok,
31+ (canCast (common_type, out.scalar_type ()) &&
32+ common_type != ScalarType::Bool),
8533 InvalidArgument,
8634 out);
8735
88- ET_KERNEL_CHECK (ctx, tensor_is_real_type (out), InvalidArgument, out);
89-
36+ // Check Dim Order
9037 ET_KERNEL_CHECK (
9138 ctx, tensors_have_same_dim_order (a, b, out), InvalidArgument, out);
9239
93- ScalarType a_type = a.scalar_type ();
94- ScalarType b_type = b.scalar_type ();
95- ScalarType common_type = promoteTypes (a_type, b_type);
96- ScalarType out_type = out.scalar_type ();
40+ // Resize
41+ ET_KERNEL_CHECK (
42+ ctx,
43+ resize_to_broadcast_target_size (a, b, out) == Error::Ok,
44+ InvalidArgument,
45+ out);
46+
47+ // Compute Dtype
48+ ScalarType compute_type = utils::get_compute_type (common_type);
9749
98- ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
50+ // @lint-ignore CLANGTIDY facebook-hte-CArray
51+ static constexpr const char op_name[] = " floor_divide.out" ;
9952
100- auto div_by_zero_error = false ;
53+ bool div_by_zero_error = false ;
10154
102- ET_SWITCH_REAL_TYPES_AND (
103- Bool, a_type, ctx, " floor_divide.out" , CTYPE_A, [&]() {
104- ET_SWITCH_REAL_TYPES_AND (
105- Bool, b_type, ctx, " floor_divide.out" , CTYPE_B, [&]() {
106- using CTYPE_IN = typename torch::executor::
107- promote_types<CTYPE_A, CTYPE_B>::type;
108- ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
109- ET_SWITCH_REAL_TYPES (
110- out_type, ctx, " floor_divide.out" , CTYPE_OUT, [&]() {
111- FloorDivideInner<
112- can_cast<CTYPE_IN, CTYPE_OUT>::value,
113- CTYPE_A,
114- CTYPE_B,
115- CTYPE_IN,
116- CTYPE_OUT>::run (a, b, out, div_by_zero_error);
117- });
118- });
119- });
55+ ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56+ utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
57+ [&div_by_zero_error](
58+ const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
59+ if (is_integral_type<CTYPE_COMPUTE, /* includeBool=*/ true >::value) {
60+ if (val_b == 0 ) {
61+ div_by_zero_error = true ;
62+ return static_cast <CTYPE_COMPUTE>(0 );
63+ }
64+ }
65+ return utils::floor_divide (val_a, val_b);
66+ },
67+ ctx,
68+ a,
69+ utils::SupportedTensorDtypes::REALHBBF16,
70+ b,
71+ utils::SupportedTensorDtypes::REALHBBF16,
72+ out,
73+ utils::SupportedTensorDtypes::REALHBF16);
74+ });
12075
12176 ET_KERNEL_CHECK_MSG (
12277 ctx,
0 commit comments