 * LICENSE file in the root directory of this source tree.
 */

-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
-#include <executorch/kernels/portable/cpu/util/math_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/platform/assert.h>
-#include <cmath>
+#include <executorch/kernels/portable/cpu/op_div_impl.h>

 namespace torch {
 namespace executor {
 namespace native {

-namespace {
-
-ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
-  if (isFloatingType(a_type) && isFloatingType(b_type)) {
-    return promoteTypes(a_type, b_type);
-  } else if (isFloatingType(a_type)) {
-    return a_type;
-  } else if (isFloatingType(b_type)) {
-    return b_type;
-  }
-  return ScalarType::Float;
-}
-
-} // namespace
-
 Tensor& div_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.out";
-
-  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          return val_a / val_b;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::FLOATHBF16);
-  });
-
-  return out;
+  return div_out_impl(ctx, a, b, out);
 }

 Tensor& div_out_mode(
@@ -80,124 +26,15 @@ Tensor& div_out_mode(
     const Tensor& b,
     exec_aten::optional<exec_aten::string_view> mode,
     Tensor& out) {
-  if (!mode.has_value()) {
-    return div_out(ctx, a, b, out);
-  }
-
-  auto mode_val = mode.value();
-
-  // Check mode
-  ET_KERNEL_CHECK(
-      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
-
-  // Common Dtype
-  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       common_type != ScalarType::Bool),
-      InvalidArgument,
-      out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.out_mode";
-
-  const bool mode_is_trunc = mode_val == "trunc";
-  bool div_by_zero_error = false;
-
-  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [mode_is_trunc, &div_by_zero_error](
-            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
-            if (val_b == 0) {
-              div_by_zero_error = true;
-              return static_cast<CTYPE_COMPUTE>(0);
-            }
-          }
-          CTYPE_COMPUTE value = val_a / val_b;
-          if (mode_is_trunc) {
-            value = std::trunc(value);
-          } else {
-            // We established above that the mode is either trunc or floor, so
-            // it must be floor.
-            value = utils::floor_divide(val_a, val_b);
-          }
-          return value;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
-  });
-
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !div_by_zero_error,
-      InvalidArgument,
-      out,
-      "Div mode operation encountered integer division by zero");
-
-  return out;
+  return div_out_mode_impl(ctx, a, b, mode, out);
 }

 Tensor& div_scalar_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      isFloatingType(a.scalar_type()) ? a.scalar_type() : ScalarType::Float;
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.Scalar_out";
-
-  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
-  });
-
-  return out;
+  return div_scalar_out_impl(ctx, a, b, out);
 }

 Tensor& div_scalar_mode_out(
@@ -206,72 +43,7 @@ Tensor& div_scalar_mode_out(
     const Scalar& b,
     exec_aten::optional<exec_aten::string_view> mode,
     Tensor& out) {
-  if (!mode.has_value()) {
-    return div_scalar_out(ctx, a, b, out);
-  }
-
-  auto mode_val = mode.value();
-
-  // Check mode
-  ET_KERNEL_CHECK(
-      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
-
-  // Common Dtype
-  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       common_type != ScalarType::Bool),
-      InvalidArgument,
-      out);
-
-  // Check for integral division by zero
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !(executorch::runtime::isIntegralType(common_type, true) &&
-        utils::scalar_to<double>(b) == 0),
-      InvalidArgument,
-      out,
-      "Div mode operation encountered integer division by zero");
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  const bool mode_is_trunc = mode_val == "trunc";
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.Scalar_mode_out";
-
-  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) {
-          CTYPE_COMPUTE value = val_a / val_b;
-          if (mode_is_trunc) {
-            value = std::trunc(value);
-          } else {
-            value = utils::floor_divide(val_a, val_b);
-          }
-          return value;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
-  });
-
-  return out;
+  return div_scalar_mode_out_impl(ctx, a, b, mode, out);
 }

 } // namespace native
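
All four kernel entry points above now forward to shared implementations declared in op_div_impl.h. As a reading aid, here is a minimal sketch of the declarations that header would need to expose, inferred only from the forwarding calls and the wrapper signatures in this diff; the actual contents and layout of op_div_impl.h are an assumption and may differ.

// Sketch only: declarations implied by the forwarding calls above.
// The real op_div_impl.h may add helpers or organize these differently.
#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& div_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out);

Tensor& div_out_mode_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    exec_aten::optional<exec_aten::string_view> mode,
    Tensor& out);

Tensor& div_scalar_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out);

Tensor& div_scalar_mode_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    exec_aten::optional<exec_aten::string_view> mode,
    Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch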