 #include <cmath>

 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>

 namespace torch {
 namespace executor {
 namespace native {

-using Tensor = exec_aten::Tensor;
-
-namespace {
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct PowInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct PowInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = std::pow(a_casted, b_casted);
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct PowInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug {};
-
-} // namespace
-
 Tensor& pow_Tensor_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
       out);

-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
-  ScalarType out_type = out.scalar_type();
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

+  // Resize
   ET_KERNEL_CHECK(
-      ctx, common_type != exec_aten::ScalarType::Bool, InvalidArgument, out);
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_REALHB_TYPES(a_type, ctx, "pow.Tensor_Tensor_out", CTYPE_A, [&]() {
-    ET_SWITCH_REALHB_TYPES(
-        b_type, ctx, "pow.Tensor_Tensor_out", CTYPE_B, [&]() {
-          using CTYPE_IN = typename torch::executor::
-              promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-          ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-          ET_SWITCH_REALH_TYPES(
-              out_type, ctx, "pow.Tensor_Tensor_out", CTYPE_OUT, [&]() {
-                PowInner<
-                    !std::is_same<CTYPE_IN, bool>::value &&
-                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
-                    CTYPE_A,
-                    CTYPE_B,
-                    CTYPE_IN,
-                    CTYPE_OUT>::run(a, b, out);
-              });
-        });
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
+  }
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "pow.Tensor_Tensor_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return std::pow(val_a, val_b);
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });

   return out;
@@ -114,51 +74,43 @@ Tensor& pow_Tensor_Scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
       ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
+      out);

-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type =
-      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
-  ScalarType out_type = out.scalar_type();
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

-  if (common_type == ScalarType::Half) {
-    common_type = ScalarType::Float;
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
   }

-  ET_SWITCH_REALHB_TYPES(a_type, ctx, "pow.Tensor_Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(
-        b_type, ctx, "pow.Tensor_Scalar_out", CTYPE_B, [&]() {
-          ET_SWITCH_REAL_TYPES(
-              common_type, ctx, "pow.Tensor_Scalar_out", CTYPE_IN, [&]() {
-                ET_SWITCH_REALH_TYPES(
-                    out_type, ctx, "pow.Tensor_Scalar_out", CTYPE_OUT, [&]() {
-                      CTYPE_B val_b = 0;
-                      utils::extract_scalar(b, &val_b);
-                      apply_unary_map_fn(
-                          [val_b](const CTYPE_A val_a) {
-                            CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                            CTYPE_IN value = std::pow(a_casted, b_casted);
-
-                            return static_cast<CTYPE_OUT>(value);
-                          },
-                          a.const_data_ptr<CTYPE_A>(),
-                          out.mutable_data_ptr<CTYPE_OUT>(),
-                          out.numel());
-                    });
-              });
-        });
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "pow.Tensor_Scalar_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b](const CTYPE_COMPUTE val_a) { return std::pow(val_a, val_b); },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });

   return out;
@@ -169,49 +121,43 @@ Tensor& pow_Scalar_out(
     const Scalar& a,
     const Tensor& b,
     Tensor& out) {
-  (void)ctx;
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(b.scalar_type(), a);

-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
       ctx,
-      resize_tensor(out, b.sizes()) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
+      out);

-  ScalarType a_type = utils::get_scalar_dtype(a);
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type =
-      utils::promote_type_with_scalar(b_type, a, /*half_to_float*/ false);
-  ScalarType out_type = out.scalar_type();
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(b, out), InvalidArgument, out);

-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, b.sizes()) == Error::Ok, InvalidArgument, out);

-  if (common_type == ScalarType::Half) {
-    common_type = ScalarType::Float;
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
   }

-  ET_SWITCH_SCALAR_OBJ_TYPES(a_type, ctx, "pow.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_REALHB_TYPES(b_type, ctx, "pow.Scalar_out", CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES(common_type, ctx, "pow.Scalar_out", CTYPE_IN, [&]() {
-        ET_SWITCH_REALH_TYPES(
-            out_type, ctx, "pow.Scalar_out", CTYPE_OUT, [&]() {
-              CTYPE_A val_a = 0;
-              utils::extract_scalar(a, &val_a);
-
-              apply_unary_map_fn(
-                  [val_a](const CTYPE_B val_b) {
-                    CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                    CTYPE_IN value = std::pow(a_casted, b_casted);
-                    return static_cast<CTYPE_OUT>(value);
-                  },
-                  b.const_data_ptr<CTYPE_B>(),
-                  out.mutable_data_ptr<CTYPE_OUT>(),
-                  out.numel());
-            });
-      });
-    });
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "pow.Scalar_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_a = utils::scalar_to<CTYPE_COMPUTE>(a);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_a](const CTYPE_COMPUTE val_b) { return std::pow(val_a, val_b); },
+        ctx,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });

   return out;