 #include <cmath>

 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

 namespace torch {
 namespace executor {
 namespace native {

-using Tensor = exec_aten::Tensor;
-
-namespace {
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct PowInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct PowInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = std::pow(a_casted, b_casted);
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct PowInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug {};
-
-} // namespace
-
 Tensor& pow_Tensor_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
       out);

-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
-  ScalarType out_type = out.scalar_type();
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

+  // Resize
   ET_KERNEL_CHECK(
-      ctx, common_type != exec_aten::ScalarType::Bool, InvalidArgument, out);
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_REALHB_TYPES(a_type, ctx, "pow.Tensor_Tensor_out", CTYPE_A, [&]() {
-    ET_SWITCH_REALHB_TYPES(
-        b_type, ctx, "pow.Tensor_Tensor_out", CTYPE_B, [&]() {
-          using CTYPE_IN = typename torch::executor::
-              promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-          ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-          ET_SWITCH_REALH_TYPES(
-              out_type, ctx, "pow.Tensor_Tensor_out", CTYPE_OUT, [&]() {
-                PowInner<
-                    !std::is_same<CTYPE_IN, bool>::value &&
-                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
-                    CTYPE_A,
-                    CTYPE_B,
-                    CTYPE_IN,
-                    CTYPE_OUT>::run(a, b, out);
-              });
-        });
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
+  }
+
+  static constexpr const char op_name[] = "pow.Tensor_Tensor_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return std::pow(val_a, val_b);
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });

   return out;
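
The rewritten Tensor-Tensor kernel no longer instantiates a lambda per (input, input, output) dtype combination: it switches only over the floating compute type and lets the elementwise utility cast each operand on load and each result on store. The standalone sketch below illustrates that load-compute-store pattern in plain C++ over already-broadcast, contiguous buffers; the names are illustrative only, not the ExecuTorch helper.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative "compute in one type, cast at the edges" loop: each pair of
// elements is converted to CTYPE_COMPUTE, combined with std::pow, and the
// result is narrowed to the output element type.
template <typename CTYPE_COMPUTE, typename IN_A, typename IN_B, typename OUT>
void pow_elementwise(const IN_A* a, const IN_B* b, OUT* out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    const auto va = static_cast<CTYPE_COMPUTE>(a[i]);
    const auto vb = static_cast<CTYPE_COMPUTE>(b[i]);
    out[i] = static_cast<OUT>(std::pow(va, vb));
  }
}

int main() {
  std::vector<int8_t> a = {2, 3, 4};
  std::vector<int32_t> b = {10, 2, 3};
  std::vector<float> out(a.size());
  // Integer inputs compute in double here, mirroring the kernel's fallback to
  // ScalarType::Double whenever the compute type is not Float.
  pow_elementwise<double>(a.data(), b.data(), out.data(), out.size());
  return 0;  // out now holds {1024.f, 9.f, 64.f}
}
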
@@ -114,51 +73,44 @@ Tensor& pow_Tensor_Scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
       ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
+      out);

-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type =
-      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
-  ScalarType out_type = out.scalar_type();
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

-  if (common_type == ScalarType::Half) {
-    common_type = ScalarType::Float;
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
   }

-  ET_SWITCH_REALHB_TYPES(a_type, ctx, "pow.Tensor_Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(
-        b_type, ctx, "pow.Tensor_Scalar_out", CTYPE_B, [&]() {
-          ET_SWITCH_REAL_TYPES(
-              common_type, ctx, "pow.Tensor_Scalar_out", CTYPE_IN, [&]() {
-                ET_SWITCH_REALH_TYPES(
-                    out_type, ctx, "pow.Tensor_Scalar_out", CTYPE_OUT, [&]() {
-                      CTYPE_B val_b = 0;
-                      utils::extract_scalar(b, &val_b);
-                      apply_unary_map_fn(
-                          [val_b](const CTYPE_A val_a) {
-                            CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                            CTYPE_IN value = std::pow(a_casted, b_casted);
-
-                            return static_cast<CTYPE_OUT>(value);
-                          },
-                          a.const_data_ptr<CTYPE_A>(),
-                          out.mutable_data_ptr<CTYPE_OUT>(),
-                          out.numel());
-                    });
-              });
-        });
+  static constexpr const char op_name[] = "pow.Tensor_Scalar_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b](const CTYPE_COMPUTE val_a) {
+          return std::pow(val_a, val_b);
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });

   return out;
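
In the Tensor-Scalar variant the exponent is converted to the compute type once, outside the per-element lambda, and captured by value, so the inner loop never re-extracts the Scalar. A minimal plain-C++ sketch of that capture pattern (illustrative, not the ExecuTorch utility):

#include <algorithm>
#include <cmath>
#include <vector>

int main() {
  std::vector<float> a = {1.5f, 2.0f, 2.5f};
  std::vector<float> out(a.size());
  // Convert the scalar exponent to the compute type once, then reuse it for
  // every element via the by-value lambda capture.
  const double val_b = 3.0;
  std::transform(a.begin(), a.end(), out.begin(), [val_b](float val_a) {
    return static_cast<float>(std::pow(static_cast<double>(val_a), val_b));
  });
  return 0;  // out now holds {3.375f, 8.f, 15.625f}
}
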
@@ -169,49 +121,44 @@ Tensor& pow_Scalar_out(
     const Scalar& a,
     const Tensor& b,
     Tensor& out) {
-  (void)ctx;
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(b.scalar_type(), a);

-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
       ctx,
-      resize_tensor(out, b.sizes()) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
+      out);

-  ScalarType a_type = utils::get_scalar_dtype(a);
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type =
-      utils::promote_type_with_scalar(b_type, a, /*half_to_float*/ false);
-  ScalarType out_type = out.scalar_type();
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(b, out), InvalidArgument, out);

-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, b.sizes()) == Error::Ok, InvalidArgument, out);

-  if (common_type == ScalarType::Half) {
-    common_type = ScalarType::Float;
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
   }

-  ET_SWITCH_SCALAR_OBJ_TYPES(a_type, ctx, "pow.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_REALHB_TYPES(b_type, ctx, "pow.Scalar_out", CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES(common_type, ctx, "pow.Scalar_out", CTYPE_IN, [&]() {
-        ET_SWITCH_REALH_TYPES(
-            out_type, ctx, "pow.Scalar_out", CTYPE_OUT, [&]() {
-              CTYPE_A val_a = 0;
-              utils::extract_scalar(a, &val_a);
-
-              apply_unary_map_fn(
-                  [val_a](const CTYPE_B val_b) {
-                    CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                    CTYPE_IN value = std::pow(a_casted, b_casted);
-                    return static_cast<CTYPE_OUT>(value);
-                  },
-                  b.const_data_ptr<CTYPE_B>(),
-                  out.mutable_data_ptr<CTYPE_OUT>(),
-                  out.numel());
-            });
-      });
-    });
+  static constexpr const char op_name[] = "pow.Scalar_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_a = utils::scalar_to<CTYPE_COMPUTE>(a);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_a](const CTYPE_COMPUTE val_b) {
+          return std::pow(val_a, val_b);
+        },
+        ctx,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });

   return out;
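
All three variants compute in float or double even when the inputs are integral, which lines up with std::pow itself: its additional overloads convert integral arguments to double, so there is no integer-returning path to preserve. A short illustration in standard C++:

#include <cmath>
#include <cstdio>
#include <type_traits>

int main() {
  // std::pow has no integer-returning overload: integral arguments are
  // converted to double and the result is double.
  auto r = std::pow(2, 10);
  static_assert(
      std::is_same<decltype(r), double>::value, "pow(int, int) yields double");
  std::printf("%f\n", r);  // prints 1024.000000
  return 0;
}
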