1010#include < ATen/cpu/vec/vec.h>
1111#include < executorch/kernels/optimized/cpu/binary_ops.h>
1212#include < executorch/kernels/portable/cpu/scalar_utils.h>
13- #include < executorch/kernels/portable/cpu/util/broadcast_util .h>
13+ #include < executorch/kernels/portable/cpu/util/elementwise_util .h>
1414#include < executorch/runtime/kernel/kernel_includes.h>
1515#include < executorch/runtime/platform/assert.h>
1616
@@ -20,7 +20,7 @@ namespace native {
2020
2121namespace {
2222
23- ScalarType get_compute_type (ScalarType a_type, ScalarType b_type) {
23+ ScalarType get_common_type (ScalarType a_type, ScalarType b_type) {
2424 ET_CHECK (
2525 !isComplexType (a_type) && !isQIntType (a_type) && !isBitsType (a_type));
2626 ET_CHECK (
@@ -43,14 +43,27 @@ Tensor& opt_div_out(
4343 const Tensor& a,
4444 const Tensor& b,
4545 Tensor& out) {
46- (void )ctx;
46+ // Check Dim Order
47+ ET_KERNEL_CHECK (
48+ ctx, tensors_have_same_dim_order (a, b, out), InvalidArgument, out);
49+
50+ // Resize
51+ ET_KERNEL_CHECK (
52+ ctx,
53+ resize_to_broadcast_target_size (a, b, out) == Error::Ok,
54+ InvalidArgument,
55+ out);
56+
57+ // @lint-ignore CLANGTIDY facebook-hte-CArray
58+ static constexpr const char op_name[] = " div.out" ;
4759
4860 ScalarType a_type = a.scalar_type ();
4961 ScalarType b_type = b.scalar_type ();
5062 ScalarType out_type = out.scalar_type ();
5163
5264 if (a.numel () == 1 || b.numel () == 1 ) {
53- if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
65+ if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
66+ a_type != ScalarType::BFloat16) {
5467 const Tensor* tensor;
5568 const Tensor* scalar;
5669 ScalarType tensor_type;
@@ -66,13 +79,8 @@ Tensor& opt_div_out(
6679 scalar = &b;
6780 scalar_type = b_type;
6881 }
69- ET_KERNEL_CHECK (
70- ctx,
71- resize_to_broadcast_target_size (a, b, out) == Error::Ok,
72- InvalidArgument,
73- out);
74- ET_SWITCH_REALB_TYPES (tensor_type, ctx, " div.out" , CTYPE, [&]() {
75- ET_SWITCH_REALB_TYPES (scalar_type, ctx, " div.out" , CTYPE_SCALAR, [&]() {
82+ ET_SWITCH_REALB_TYPES (tensor_type, ctx, op_name, CTYPE, [&]() {
83+ ET_SWITCH_REALB_TYPES (scalar_type, ctx, op_name, CTYPE_SCALAR, [&]() {
7684 CTYPE_SCALAR scalar_val = *scalar->const_data_ptr <CTYPE_SCALAR>();
7785 CTYPE scalar_casted = static_cast <CTYPE>(scalar_val);
7886
@@ -101,16 +109,7 @@ Tensor& opt_div_out(
101109
102110 auto selected_optimized_path = select_optimized_path (a, b, out);
103111 if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d ) {
104- // Resize for dynamic shape
105- auto error = resize_tensor (out, a.sizes ());
106- ET_KERNEL_CHECK_MSG (
107- ctx,
108- error == Error::Ok,
109- InvalidArgument,
110- out,
111- " Failed to resize output tensor." );
112-
113- ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, " div.out" , CTYPE, [&]() {
112+ ET_SWITCH_REALB_TYPES (out_type, ctx, op_name, CTYPE, [&]() {
114113 using Vec = at::vec::Vectorized<CTYPE>;
115114 at::vec::map2<CTYPE>(
116115 [](Vec x, Vec y) { return x / y; },
@@ -122,7 +121,7 @@ Tensor& opt_div_out(
122121 } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone ) {
123122 // Reason for using alpha is becasuse handle_broadcast_elementwise
124123 // is used for add and sub as well:
125- ET_SWITCH_REALB_TYPES (out_type, ctx, " div.out " , CTYPE, [&]() {
124+ ET_SWITCH_REALB_TYPES (out_type, ctx, op_name , CTYPE, [&]() {
126125 if (selected_optimized_path ==
127126 ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
128127 selected_optimized_path ==
@@ -139,33 +138,21 @@ Tensor& opt_div_out(
139138 }
140139 });
141140 } else {
142- ScalarType common_type = get_compute_type (a_type, b_type);
143- ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
144-
145- ET_KERNEL_CHECK (
146- ctx,
147- resize_to_broadcast_target_size (a, b, out) == Error::Ok,
148- InvalidArgument,
149- out);
150-
151- ET_SWITCH_REALB_TYPES (a_type, ctx, " div.out" , CTYPE_A, [&]() {
152- ET_SWITCH_REALB_TYPES (b_type, ctx, " div.out" , CTYPE_B, [&]() {
153- ET_SWITCH_REALB_TYPES (common_type, ctx, " div.out" , CTYPE_IN, [&]() {
154- ET_SWITCH_REALB_TYPES (out_type, ctx, " div.out" , CTYPE_OUT, [&]() {
155- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
156- [](const CTYPE_A val_a, const CTYPE_B val_b) {
157- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
158- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
159- CTYPE_IN value = a_casted / b_casted;
160-
161- return static_cast <CTYPE_OUT>(value);
162- },
163- a,
164- b,
165- out);
166- });
167- });
168- });
141+ ScalarType common_type = get_common_type (a.scalar_type (), b.scalar_type ());
142+ ScalarType compute_type = utils::get_compute_type (common_type);
143+
144+ ET_SWITCH_FLOAT_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
145+ utils::apply_bitensor_elementwise_fn<
146+ CTYPE_COMPUTE,
147+ op_name,
148+ utils::SupportedTensorDtypes::FLOATHBF16>(
149+ [](const auto val_a, const auto val_b) { return val_a / val_b; },
150+ ctx,
151+ a,
152+ utils::SupportedTensorDtypes::REALHBBF16,
153+ b,
154+ utils::SupportedTensorDtypes::REALHBBF16,
155+ out);
169156 });
170157 }
171158
@@ -177,63 +164,57 @@ Tensor& opt_div_scalar_out(
177164 const Tensor& a,
178165 const Scalar& b,
179166 Tensor& out) {
180- (void )ctx;
181-
182167 ScalarType a_type = a.scalar_type ();
183168 ScalarType b_type = utils::get_scalar_dtype (b);
184169 ScalarType common_type = isFloatingType (a_type) ? a_type : ScalarType::Float;
185170 ScalarType out_type = out.scalar_type ();
186171
187- ET_CHECK (common_type == out_type);
188-
189- // Resize for dynamic shape
190- auto error = resize_tensor (out, a.sizes ());
191- ET_CHECK_MSG (error == Error::Ok, " Failed to resize output tensor." );
192-
193- if (a_type == common_type && a_type == out_type) {
194- ET_SWITCH_REAL_TYPES (a_type, ctx, " div.Scalar_out" , CTYPE, [&]() {
195- ET_SWITCH_REAL_TYPES_AND (
196- Bool, b_type, ctx, " div.Scalar_out" , CTYPE_B, [&]() {
197- CTYPE_B b_val;
198- ET_EXTRACT_SCALAR (b, b_val);
199- CTYPE b_casted = static_cast <CTYPE>(b_val);
200-
201- using Vec = at::vec::Vectorized<CTYPE>;
202- Vec inv_b_casted_vec (CTYPE (1 ) / b_casted);
203- at::vec::map<CTYPE>(
204- [inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; },
205- out.mutable_data_ptr <CTYPE>(),
206- a.const_data_ptr <CTYPE>(),
207- out.numel ());
208- });
172+ // Check Common Dtype
173+ ET_KERNEL_CHECK (ctx, common_type == out_type, InvalidArgument, out);
174+
175+ // Check Dim Order
176+ ET_KERNEL_CHECK (
177+ ctx, tensors_have_same_dim_order (a, out), InvalidArgument, out);
178+
179+ // Resize
180+ ET_KERNEL_CHECK (
181+ ctx, resize_tensor (out, a.sizes ()) == Error::Ok, InvalidArgument, out);
182+
183+ // @lint-ignore CLANGTIDY facebook-hte-CArray
184+ static constexpr const char op_name[] = " div.Scalar_out" ;
185+
186+ if (a_type == common_type && a_type == out_type &&
187+ a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
188+ ET_SWITCH_REAL_TYPES (a_type, ctx, op_name, CTYPE, [&]() {
189+ ET_SWITCH_REALB_TYPES (b_type, ctx, op_name, CTYPE_B, [&]() {
190+ CTYPE_B b_val;
191+ ET_EXTRACT_SCALAR (b, b_val);
192+ CTYPE b_casted = static_cast <CTYPE>(b_val);
193+
194+ using Vec = at::vec::Vectorized<CTYPE>;
195+ Vec inv_b_casted_vec (CTYPE (1 ) / b_casted);
196+ at::vec::map<CTYPE>(
197+ [inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; },
198+ out.mutable_data_ptr <CTYPE>(),
199+ a.const_data_ptr <CTYPE>(),
200+ out.numel ());
201+ });
209202 });
210203 } else {
211- ET_SWITCH_REAL_TYPES_AND (
212- Bool, a_type, ctx, " div.Scalar_out" , CTYPE_A, [&]() {
213- ET_SWITCH_REAL_TYPES_AND (
214- Bool, b_type, ctx, " div.Scalar_out" , CTYPE_B, [&]() {
215- ET_SWITCH_REAL_TYPES (
216- common_type, ctx, " div.Scalar_out" , CTYPE_IN, [&]() {
217- ET_SWITCH_REAL_TYPES (
218- out_type, ctx, " div.Scalar_out" , CTYPE_OUT, [&]() {
219- CTYPE_B b_val;
220- ET_EXTRACT_SCALAR (b, b_val);
221- CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
222- CTYPE_IN inv_b_casted = CTYPE_IN (1 ) / b_casted;
223-
224- const size_t n = a.numel ();
225- const CTYPE_A* a_data = a.const_data_ptr <CTYPE_A>();
226- CTYPE_OUT* out_data =
227- out.mutable_data_ptr <CTYPE_OUT>();
228- for (auto i = 0 ; i < n; ++i) {
229- out_data[i] = static_cast <CTYPE_OUT>(
230- static_cast <CTYPE_IN>(a_data[i]) *
231- inv_b_casted);
232- }
233- });
234- });
235- });
236- });
204+ ScalarType compute_type = utils::get_compute_type (common_type);
205+
206+ ET_SWITCH_FLOAT_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
207+ const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
208+ utils::apply_unitensor_elementwise_fn<
209+ CTYPE_COMPUTE,
210+ op_name,
211+ utils::SupportedTensorDtypes::SAME_AS_COMMON>(
212+ [val_b](const auto val_a) { return val_a / val_b; },
213+ ctx,
214+ a,
215+ utils::SupportedTensorDtypes::REALHBBF16,
216+ out);
217+ });
237218 }
238219
239220 return out;
0 commit comments