
Commit fd955cf

Merge pull request #11 from dijopaul/main
Namespace update as per review comments
2 parents: 8064895 + a3581f1

7 files changed: 88 additions, 347 deletions

backends/cadence/aot/functions_hifi.yaml

Lines changed: 7 additions & 7 deletions
@@ -25,7 +25,7 @@
 - op: add.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::add_out
+      kernel_name: impl::HiFi::add_out

 - op: bmm.out
   kernels:
@@ -45,12 +45,12 @@
 - op: div.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::div_out
+      kernel_name: impl::HiFi::div_out

 - op: div.out_mode
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::div_out_mode
+      kernel_name: impl::HiFi::div_out_mode

 - op: embedding.out
   kernels:
@@ -65,7 +65,7 @@
 - op: mul.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::mul_out
+      kernel_name: impl::HiFi::mul_out

 - op: permute_copy.out
   kernels:
@@ -75,7 +75,7 @@
 - op: sigmoid.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::sigmoid_out
+      kernel_name: impl::HiFi::sigmoid_out

 - op: slice_copy.Tensor_out
   kernels:
@@ -90,12 +90,12 @@
 - op: sub.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::sub_out
+      kernel_name: impl::HiFi::sub_out

 - op: tanh.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::tanh_out
+      kernel_name: impl::HiFi::tanh_out

 - op: view_copy.out
   kernels:
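
Note: each kernel_name above now points at the Cadence HiFi implementation namespace instead of the portable torch::executor kernels. As with the previous torch::executor entries, the kernel codegen is expected to resolve impl::HiFi::add_out to a definition under impl::HiFi::native (see op_add.cpp below). A minimal sketch of the declaration this entry is assumed to bind to; the include paths are assumptions, and the signature follows the add_out hunks below:

// Sketch only: out-variant signature assumed for the impl::HiFi add kernel.
#include <executorch/runtime/core/exec_aten/exec_aten.h>       // assumed path
#include <executorch/runtime/kernel/kernel_runtime_context.h>  // assumed path

namespace impl {
namespace HiFi {
namespace native {

// add.out: writes a + alpha * b into the preallocated `out` tensor.
exec_aten::Tensor& add_out(
    executorch::runtime::KernelRuntimeContext& ctx,
    const exec_aten::Tensor& a,
    const exec_aten::Tensor& b,
    const exec_aten::Scalar& alpha,
    exec_aten::Tensor& out);

} // namespace native
} // namespace HiFi
} // namespace impl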

backends/cadence/hifi/operators/op_add.cpp

Lines changed: 17 additions & 85 deletions
@@ -14,11 +14,19 @@
 #include <executorch/runtime/platform/assert.h>
 #include <executorch/backends/cadence/hifi/kernels/kernels.h>

-namespace torch {
-namespace executor {
+using exec_aten::Scalar;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using executorch::runtime::can_cast;
+using executorch::runtime::CppTypeToScalarType;
+using executorch::runtime::KernelRuntimeContext;
+using torch::executor::Error;
+
+namespace impl {
+namespace HiFi {
 namespace native {
-namespace {

+namespace {
 template <
     bool can_cast,
     typename CTYPE_A,
@@ -35,7 +43,7 @@ template <
 struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
   static void
   run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+    torch::executor::apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
         // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
         [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
           CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
@@ -89,7 +97,7 @@ Tensor& add_out(

   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
-  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+  ScalarType alpha_type = torch::executor::native::utils::get_scalar_dtype(alpha);
   ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
   ScalarType out_type = out.scalar_type();

@@ -98,7 +106,7 @@ Tensor& add_out(
       ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);

   float alpha_val;
-  utils::extract_scalar(alpha, &alpha_val);
+  torch::executor::native::utils::extract_scalar(alpha, &alpha_val);

   constexpr auto name = "add.out";
   constexpr int kNnlibMaxDim = 4; /*fallback if broadcast and dim > 4 */
@@ -168,7 +176,7 @@ Tensor& add_out(
           promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
       ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
       CTYPE_IN alpha_val;
-      utils::extract_scalar(alpha, &alpha_val);
+      torch::executor::native::utils::extract_scalar(alpha, &alpha_val);

       ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
         AddInner<
@@ -184,83 +192,7 @@ Tensor& add_out(
   return out;
 }

-Tensor& add_scalar_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& a,
-    const Scalar& b,
-    const Scalar& alpha,
-    Tensor& out) {
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK(
-      ctx,
-      executorch::runtime::tensor_is_realhbbf16_type(out),
-      InvalidArgument,
-      out);
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
-  ScalarType common_type =
-      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
-  ET_KERNEL_CHECK(
-      ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
-
-  /*When Half first compute the result in float precision
-    and then downcast to half*/
-  if (common_type == ScalarType::Half) {
-    common_type = ScalarType::Float;
-  }
-
-  constexpr auto name = "add.Scalar_out";
-
-  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      using CTYPE_IN = typename utils::promote_type_with_scalar_type<
-          CTYPE_A,
-          CTYPE_B,
-          /*half_to_float*/ true>::type;
-      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-
-      CTYPE_B b_val;
-      utils::extract_scalar(b, &b_val);
-      CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-
-      CTYPE_IN alpha_val;
-      utils::extract_scalar(alpha, &alpha_val);
-
-      using CTYPE_OUT = typename std::conditional<
-          std::is_same<CTYPE_A, internal::F2>::value,
-          internal::F2,
-          CTYPE_IN>::type;
-
-      apply_unary_map_fn(
-          [b_casted, alpha_val](const CTYPE_A val_a) {
-            CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-            CTYPE_IN value = a_casted + alpha_val * b_casted;
-            return static_cast<CTYPE_OUT>(value);
-          },
-          a.const_data_ptr<CTYPE_A>(),
-          out.mutable_data_ptr<CTYPE_OUT>(),
-          out.numel());
-    });
-  });
-
-  return out;
-}

+} // namespace impl
+} // namespace HiFi
 } // namespace native
-} // namespace executor
-} // namespace torch
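
Note: condensed view of what op_add.cpp looks like after this diff (a sketch, not the full file). The kernel now sits in impl::HiFi::native, pulls types in via using-declarations, and reaches the shared portable helpers through explicit torch::executor:: qualification, while the add_scalar_out variant is removed from the HiFi build.

// Sketch of the post-commit structure of op_add.cpp; bodies elided.
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/runtime/platform/assert.h>

using exec_aten::Scalar;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
using torch::executor::Error;

namespace impl {
namespace HiFi {
namespace native {

Tensor& add_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  // Body unchanged except that portable helpers are now fully qualified, e.g.
  //   torch::executor::native::utils::extract_scalar(alpha, &alpha_val);
  //   torch::executor::apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(...);
  return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl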

backends/cadence/hifi/operators/op_div.cpp

Lines changed: 12 additions & 102 deletions
@@ -15,8 +15,14 @@
 #include <cmath>
 #include <executorch/backends/cadence/hifi/kernels/kernels.h>

-namespace torch {
-namespace executor {
+using exec_aten::Scalar;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using executorch::aten::RuntimeContext;
+using torch::executor::Error;
+
+namespace impl {
+namespace HiFi {
 namespace native {

 namespace {
@@ -127,7 +133,7 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
   ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out", CTYPE_B, [&]() {
     ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
       ET_SWITCH_FLOAT_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
-        apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        torch::executor::apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
            [](const CTYPE_A val_a, const CTYPE_B val_b) {
              CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
              CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -242,7 +248,7 @@ Tensor& div_out_mode(
   ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out_mode", CTYPE_B, [&]() {
     ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div.out_mode", CTYPE_IN, [&]() {
       ET_SWITCH_REAL_TYPES(out_type, ctx, "div.out_mode", CTYPE_OUT, [&]() {
-        apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        torch::executor::apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
            [mode](const CTYPE_A val_a, const CTYPE_B val_b) {
              CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
              CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
@@ -265,103 +271,7 @@ Tensor& div_out_mode(
   return out;
 }

-Tensor& div_scalar_out(
-    RuntimeContext& ctx,
-    const Tensor& a,
-    const Scalar& b,
-    Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
-      ET_SWITCH_FLOAT_TYPES(out_type, ctx, "div.Scalar_out", CTYPE, [&]() {
-        CTYPE_B b_val;
-        utils::extract_scalar(b, &b_val);
-        CTYPE b_casted = static_cast<CTYPE>(b_val);
-
-        apply_unary_map_fn(
-            [b_casted](const CTYPE_A val_a) {
-              CTYPE a_casted = static_cast<CTYPE>(val_a);
-              CTYPE value = a_casted / b_casted;
-              return static_cast<CTYPE>(value);
-            },
-            a.const_data_ptr<CTYPE_A>(),
-            out.mutable_data_ptr<CTYPE>(),
-            out.numel());
-      });
-    });
-  });
-
-  return out;
-}
-
-Tensor& div_scalar_mode_out(
-    RuntimeContext& ctx,
-    const Tensor& a,
-    const Scalar& b,
-    exec_aten::optional<exec_aten::string_view> mode,
-    Tensor& out) {
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
-
-  constexpr auto name = "div.Scalar_mode_out";
-
-  ET_SWITCH_REALB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES(out_type, ctx, name, CTYPE, [&]() {
-        CTYPE_B b_val;
-        utils::extract_scalar(b, &b_val);
-        CTYPE b_casted = static_cast<CTYPE>(b_val);
-
-        apply_unary_map_fn(
-            [b_casted, mode](const CTYPE_A val_a) {
-              CTYPE a_casted = static_cast<CTYPE>(val_a);
-              CTYPE value = a_casted / b_casted;
-              if (mode.has_value() && mode.value() == "trunc") {
-                value = std::trunc(value);
-              } else if (mode.has_value() && mode.value() == "floor") {
-                value = utils::floor_divide(a_casted, b_casted);
-              }
-              return value;
-            },
-            a.const_data_ptr<CTYPE_A>(),
-            out.mutable_data_ptr<CTYPE>(),
-            out.numel());
-      });
-    });
-  });
-
-  return out;
-}

+} // namespace impl
+} // namespace HiFi
 } // namespace native
-} // namespace executor
-} // namespace torch
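
Note: the practical consequence for callers is that code which previously referenced torch::executor::native::div_out now has to name the HiFi namespace. A hedged usage sketch follows; the div_out signature, namespaces, and RuntimeContext alias come from the diff above, while run_div and the include path are illustrative assumptions, not part of the commit.

#include <executorch/runtime/core/exec_aten/exec_aten.h>  // assumed path

using exec_aten::Tensor;
using executorch::aten::RuntimeContext;

// Forward declaration mirroring the definition in op_div.cpp after this commit.
namespace impl {
namespace HiFi {
namespace native {
Tensor& div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out);
} // namespace native
} // namespace HiFi
} // namespace impl

// Illustrative caller: before this commit the same call would have been
// torch::executor::native::div_out(ctx, a, b, out).
Tensor& run_div(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
  return impl::HiFi::native::div_out(ctx, a, b, out);
}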
