Commit 8064895

Merge pull request #9 from dijopaul/main
Fixing review comments in 5483
2 parents 41d7533 + a8c4f66 · commit 8064895

File tree: 6 files changed (+380, -241 lines)

backends/cadence/hifi/operators/op_add.cpp

Lines changed: 191 additions & 48 deletions
@@ -9,82 +9,140 @@
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
-#include "kernels.h"
+#include <executorch/backends/cadence/hifi/kernels/kernels.h>
 
 namespace torch {
 namespace executor {
 namespace native {
+namespace {
 
-#define NNLIB_MAX_DIM 4 /* Add fallback if broadcast and dim > 4 */
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void
+  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = a_casted + alpha_val * b_casted;
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+template <typename CTYPE_IN>
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug<CTYPE_IN> {};
+
+} // namespace
 
 Tensor& add_out(
-    RuntimeContext& ctx,
+    KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      executorch::runtime::tensor_is_realhbbf16_type(out),
+      InvalidArgument,
+      out);
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
 
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
+  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+  ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
   ScalarType out_type = out.scalar_type();
 
-  ET_CHECK_MSG(a_type == ScalarType::Float, "Input tensor not a float.\n");
-  ET_CHECK_MSG(b_type == ScalarType::Float, "Input tensor not a float.\n");
-  ET_CHECK_MSG(out_type == ScalarType::Float, "Output tensor not a float.\n");
-
-  ET_CHECK(canCast(common_type, out_type));
-
-  using CTYPE_A = float;
-  using CTYPE_B = float;
-  using CTYPE_IN = float;
-  using CTYPE_OUT = float;
-  CTYPE_IN alpha_val;
-  ET_EXTRACT_SCALAR(alpha, alpha_val);
+  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
+
+  float alpha_val;
+  utils::extract_scalar(alpha, &alpha_val);
 
+  constexpr auto name = "add.out";
+  constexpr int kNnlibMaxDim = 4; /* fallback if broadcast and dim > 4 */
+
   int a_dim = a.dim(), b_dim = b.dim(), out_dim = out.dim();
-  int fall_back = 0;
+  bool optimized = 1;
   /*find broadcast*/
-  const int a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const int b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const int broadcast = (a_is_broadcasted || b_is_broadcasted);
+  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
+  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
+  const bool broadcast = (a_is_broadcasted || b_is_broadcasted);
   int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
   max_dim = out.dim() > max_dim ? out.dim() : max_dim;
 
-  if( (out_type != ScalarType::Float) || (alpha_val != 1.0))
-    fall_back = 1;
+  if((out_type != ScalarType::Float) || (alpha_val != 1.0))
+    optimized = 0;
 
-  if( (a_dim == 0) || (b_dim == 0) )
-    fall_back = 1;
+  if((a_dim == 0) || (b_dim == 0) )
+    optimized = 0;
 
-  if((broadcast == 1) && (max_dim > NNLIB_MAX_DIM))
-    fall_back = 1;
+  if((broadcast == 1) && (max_dim > kNnlibMaxDim))
+    optimized = 0;
 
 
-  if (!fall_back)
+  if(optimized)
   {
     const float* const a_data = a.const_data_ptr<float>();
     const float* const b_data = b.const_data_ptr<float>();
     float* const out_data = out.mutable_data_ptr<float>();
     if(broadcast == 1)
     {
-      int out_shape[NNLIB_MAX_DIM];
-      int inp1_shape[NNLIB_MAX_DIM];
-      int inp2_shape[NNLIB_MAX_DIM];
+      int out_shape[kNnlibMaxDim];
+      int inp1_shape[kNnlibMaxDim];
+      int inp2_shape[kNnlibMaxDim];
 
-      for(int i = 0; i < NNLIB_MAX_DIM; i++)
+      for(int i = 0; i < kNnlibMaxDim; i++)
       {
         out_shape[i] = 1;
         inp1_shape[i] = 1;
         inp2_shape[i] = 1;
       }
 
-      int off_o = NNLIB_MAX_DIM - out.dim();
-      int off_a = NNLIB_MAX_DIM - a.dim();
-      int off_b = NNLIB_MAX_DIM - b.dim();
+      int off_o = kNnlibMaxDim - out.dim();
+      int off_a = kNnlibMaxDim - a.dim();
+      int off_b = kNnlibMaxDim - b.dim();
 
       for(int i = 0; i < out.dim(); i++)
         out_shape[i+off_o] = out.size(i);
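
The `AddInner` helper introduced above selects its implementation at compile time via a `bool` template parameter: the `true` specialization performs the element-wise add, while the `false` specialization inherits from `ReportCanCastBug` and can only trip a debug assertion. A stripped-down sketch of the same specialization pattern, with illustrative names (`AddInnerSketch` is not part of the patch):

#include <cstdio>
#include <type_traits>

// Primary template: declared but never defined, so only the two
// explicit specializations below can be instantiated.
template <bool can_cast, typename T>
struct AddInnerSketch;

// Chosen when the cast is legal: does the real work.
template <typename T>
struct AddInnerSketch<true, T> {
  static T run(T a, T b, T alpha) { return a + alpha * b; }
};

// Chosen when the cast is illegal: reachable only through a logic bug,
// mirroring ReportCanCastBug in the patch.
template <typename T>
struct AddInnerSketch<false, T> {
  static T run(T, T, T) {
    std::fprintf(stderr, "BUG: canCast should have been checked above\n");
    return T{};
  }
};

int main() {
  // The bool is a compile-time constant (a stand-in for can_cast<...>::value).
  constexpr bool kCanCast = std::is_convertible<double, float>::value;
  std::printf("%f\n", AddInnerSketch<kCanCast, float>::run(1.0f, 2.0f, 3.0f));
  return 0;
}

Resolving the branch at compile time means the unreachable path costs nothing at runtime, and a missed `canCast` guard surfaces as a loud debug failure instead of silent truncation.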
@@ -97,24 +155,109 @@ Tensor& add_out(
           b_data, inp2_shape);
     }
     else
+    {
       xa_nn_elm_add_f32xf32_f32(out_data, a_data, b_data, out.numel());
-
+    }
+
+    return out;
   }
-  else
-  {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
+
+  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
+    ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
+      using CTYPE_IN = typename torch::executor::
+          promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+      CTYPE_IN alpha_val;
+      utils::extract_scalar(alpha, &alpha_val);
+
+      ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
+        AddInner<
+            can_cast<CTYPE_IN, CTYPE_OUT>::value,
+            CTYPE_A,
+            CTYPE_B,
+            CTYPE_IN,
+            CTYPE_OUT>::run(a, b, alpha_val, out);
+      });
+    });
+  });
+
+  return out;
+}
+
+Tensor& add_scalar_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out) {
+
+  // Resize for dynamic shape
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      resize_tensor(out, a.sizes()) == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+
+  ET_KERNEL_CHECK(
+      ctx,
+      executorch::runtime::tensor_is_realhbbf16_type(out),
+      InvalidArgument,
       out);
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = utils::get_scalar_dtype(b);
+  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+  ScalarType common_type =
+      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
+  ScalarType out_type = out.scalar_type();
+
+  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
+
+  /* When Half, first compute the result in float precision
+     and then downcast to half */
+  if (common_type == ScalarType::Half) {
+    common_type = ScalarType::Float;
   }
 
+  constexpr auto name = "add.Scalar_out";
+
+  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
+    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
+      using CTYPE_IN = typename utils::promote_type_with_scalar_type<
+          CTYPE_A,
+          CTYPE_B,
+          /*half_to_float*/ true>::type;
+      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+
+      CTYPE_B b_val;
+      utils::extract_scalar(b, &b_val);
+      CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
+
+      CTYPE_IN alpha_val;
+      utils::extract_scalar(alpha, &alpha_val);
+
+      using CTYPE_OUT = typename std::conditional<
+          std::is_same<CTYPE_A, internal::F2>::value,
+          internal::F2,
+          CTYPE_IN>::type;
+
+      apply_unary_map_fn(
+          [b_casted, alpha_val](const CTYPE_A val_a) {
+            CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+            CTYPE_IN value = a_casted + alpha_val * b_casted;
+            return static_cast<CTYPE_OUT>(value);
+          },
+          a.const_data_ptr<CTYPE_A>(),
+          out.mutable_data_ptr<CTYPE_OUT>(),
+          out.numel());
+    });
+  });
+
   return out;
 }
 
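For context on the broadcast fast path in the first hunk: each shape is right-aligned into a fixed array of `kNnlibMaxDim` (4) entries with the leading slots set to 1, so the NNLIB kernel always sees a 4D problem; higher-rank inputs fall back to the portable path. A minimal standalone sketch of that padding scheme, assuming rank <= 4 and using `std::vector<int>` in place of tensor sizes (`pad_to_4d` is an illustrative name, not from the patch):

#include <iostream>
#include <vector>

constexpr int kNnlibMaxDim = 4;

// Right-align a shape into a fixed 4-slot array, padding the leading
// slots with 1s -- the same scheme the HiFi broadcast path applies to
// out_shape / inp1_shape / inp2_shape. Assumes sizes.size() <= 4.
std::vector<int> pad_to_4d(const std::vector<int>& sizes) {
  std::vector<int> padded(kNnlibMaxDim, 1);
  const int off = kNnlibMaxDim - static_cast<int>(sizes.size());
  for (int i = 0; i < static_cast<int>(sizes.size()); i++) {
    padded[i + off] = sizes[i];
  }
  return padded;
}

int main() {
  // A (2, 3) tensor broadcast against a (4, 2, 3) tensor:
  for (int d : pad_to_4d({2, 3})) std::cout << d << ' ';    // prints: 1 1 2 3
  std::cout << '\n';
  for (int d : pad_to_4d({4, 2, 3})) std::cout << d << ' '; // prints: 1 4 2 3
  std::cout << '\n';
  return 0;
}

With both shapes padded to rank 4, a fixed-rank 4D broadcast kernel can handle any legal pairing without variable-rank logic.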