Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 45 additions & 42 deletions kernels/portable/cpu/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,43 @@ Tensor& add_out(
Tensor& out) {
ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
executorch::runtime::tensor_is_realhbbf16_type(b) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out);

// Common Dtype
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
executorch::runtime::tensor_is_realhbbf16_type(b) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
(canCast(common_type, out.scalar_type()) &&
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
InvalidArgument,
out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

ScalarType a_type = a.scalar_type();
ScalarType b_type = b.scalar_type();
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
ScalarType out_type = out.scalar_type();

ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
// Resize
ET_KERNEL_CHECK(
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);

// Compute Dtype
ScalarType compute_type = utils::get_compute_type(common_type);

static constexpr const char op_name[] = "add.out";

ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMMON, op_name>(
[alpha](const CTYPE_COMMON val_a, const CTYPE_COMMON val_b) {
CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a + val_alpha * val_b;
},
a,
Expand All @@ -73,52 +80,48 @@ Tensor& add_scalar_out(
const Scalar& b,
const Scalar& alpha,
Tensor& out) {
(void)ctx;

// Resize for dynamic shape
ET_KERNEL_CHECK_MSG(
ET_KERNEL_CHECK(
ctx,
resize_tensor(out, a.sizes()) == Error::Ok,
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out,
"Failed to resize output tensor.");
out);

// Common Dtype
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
(executorch::runtime::tensor_is_realhbbf16_type(a) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
(common_type == out.scalar_type() &&
check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
InvalidArgument,
out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

ScalarType a_type = a.scalar_type();
ScalarType alpha_type = utils::get_scalar_dtype(alpha);
ScalarType common_type =
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
ScalarType out_type = out.scalar_type();

ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
// Resize
ET_KERNEL_CHECK(
ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
common_type = ScalarType::Float;
}
// Compute Dtype
ScalarType compute_type = utils::get_compute_type(common_type);

static constexpr const char op_name[] = "add.Scalar_out";

ET_SWITCH_REALB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
[b, alpha](const CTYPE_COMMON val_a) {
CTYPE_COMMON val_b = utils::scalar_to<CTYPE_COMMON>(b);
CTYPE_COMMON val_alpha = utils::scalar_to<CTYPE_COMMON>(alpha);
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[b, alpha](const CTYPE_COMPUTE val_a) {
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
return val_a + val_alpha * val_b;
},
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
utils::SupportedTensorDtypes::SAME_AS_COMMON);
});

return out;
Expand Down
126 changes: 76 additions & 50 deletions kernels/portable/cpu/op_clamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,74 +73,90 @@ Tensor& clamp_out(
const exec_aten::optional<Scalar>& min_opt,
const exec_aten::optional<Scalar>& max_opt,
Tensor& out) {
(void)ctx;
ET_KERNEL_CHECK(
ctx,
(executorch::runtime::tensor_is_realhbbf16_type(in) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out);

bool has_min = min_opt.has_value();
bool has_max = max_opt.has_value();

ET_KERNEL_CHECK_MSG(
ctx,
resize_tensor(out, in.sizes()) == Error::Ok,
has_min || has_max,
InvalidArgument,
out,
"Failed to resize output tensor.");

ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
"At least one of 'min' or 'max' must not be None");

// Input Dtypes
ScalarType in_type = in.scalar_type();
ScalarType min_type = in_type;
ScalarType max_type = in_type;
ScalarType common_type = in_type;
ScalarType min_type =
has_min ? utils::get_scalar_dtype(min_opt.value()) : in_type;
ScalarType max_type =
has_max ? utils::get_scalar_dtype(max_opt.value()) : in_type;
ScalarType out_type = out.scalar_type();

bool has_min = min_opt.has_value();
// Common Dtype
ScalarType common_type = in_type;
if (has_min) {
min_type = utils::get_scalar_dtype(min_opt.value());
common_type = utils::promote_type_with_scalar(common_type, min_opt.value());
}
if (has_max) {
common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
}

// Check Common Dtype
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);

// Check Scalar Bounds
if (has_min) {
ET_KERNEL_CHECK(
ctx,
check_bounds(min_opt.value(), min_type, out_type, "minimum"),
InvalidArgument,
out);
}
bool has_max = max_opt.has_value();
if (has_max) {
max_type = utils::get_scalar_dtype(max_opt.value());
common_type = utils::promote_type_with_scalar(common_type, max_opt.value());
ET_KERNEL_CHECK(
ctx,
check_bounds(max_opt.value(), max_type, out_type, "maximum"),
InvalidArgument,
out);
}

ET_KERNEL_CHECK_MSG(
ctx,
has_min || has_max,
InvalidArgument,
out,
"At least one of 'min' or 'max' must not be None");
// Check Dim Order
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
// Resize
ET_KERNEL_CHECK(
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

// Compute Dtype
ScalarType compute_type = utils::get_compute_type(common_type);

static constexpr const char op_name[] = "clamp.out";

ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
utils::apply_unitensor_elementwise_fn<CTYPE_COMMON, op_name>(
[has_min, min_opt, has_max, max_opt](const CTYPE_COMMON val_in) {
CTYPE_COMMON val_out = val_in;
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
CTYPE_COMPUTE val_out = val_in;
if (has_min) {
val_out = utils::max_override(
val_out, utils::scalar_to<CTYPE_COMMON>(min_opt.value()));
val_out, utils::scalar_to<CTYPE_COMPUTE>(min_opt.value()));
}
if (has_max) {
val_out = utils::min_override(
val_out, utils::scalar_to<CTYPE_COMMON>(max_opt.value()));
val_out, utils::scalar_to<CTYPE_COMPUTE>(max_opt.value()));
}
return val_out;
},
in,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
utils::SupportedTensorDtypes::SAME_AS_COMMON);
});

return out;
Expand All @@ -152,8 +168,6 @@ Tensor& clamp_tensor_out(
const exec_aten::optional<Tensor>& min_opt,
const exec_aten::optional<Tensor>& max_opt,
Tensor& out) {
(void)ctx;

bool has_min = min_opt.has_value();
bool has_max = max_opt.has_value();

Expand All @@ -167,42 +181,54 @@ Tensor& clamp_tensor_out(
const Tensor& min = has_min ? min_opt.value() : in;
const Tensor& max = has_max ? max_opt.value() : in;

ET_KERNEL_CHECK(
ctx,
(executorch::runtime::tensor_is_realhbbf16_type(in) &&
executorch::runtime::tensor_is_realhbbf16_type(min) &&
executorch::runtime::tensor_is_realhbbf16_type(max) &&
executorch::runtime::tensor_is_realhbbf16_type(out)),
InvalidArgument,
out);

// Common Dtype
ScalarType common_type = in.scalar_type();
if (has_min) {
common_type = promoteTypes(common_type, min.scalar_type());
}
if (has_max) {
common_type = promoteTypes(common_type, max.scalar_type());
}

// Check Common Dtype
ET_KERNEL_CHECK(
ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx,
tensors_have_same_dim_order(in, min, max, out),
InvalidArgument,
out);

// Resize
ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(in, min, max, out) == Error::Ok,
InvalidArgument,
out);

ScalarType in_type = in.scalar_type();
ScalarType min_type = min.scalar_type();
ScalarType max_type = max.scalar_type();
ScalarType common_type = in_type;
ScalarType out_type = out.scalar_type();

if (has_min) {
common_type = promoteTypes(common_type, min_type, /*half_to_float*/ true);
}
if (has_max) {
common_type = promoteTypes(common_type, max_type, /*half_to_float*/ true);
}

ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
// Compute Dtype
ScalarType compute_type = utils::get_compute_type(common_type);

static constexpr const char op_name[] = "clamp.Tensor_out";

ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
utils::apply_tritensor_elementwise_fn<CTYPE_COMMON, op_name>(
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[has_min, has_max](
const CTYPE_COMMON val_in,
const CTYPE_COMMON val_min,
const CTYPE_COMMON val_max) {
CTYPE_COMMON val_out = val_in;
const CTYPE_COMPUTE val_in,
const CTYPE_COMPUTE val_min,
const CTYPE_COMPUTE val_max) {
CTYPE_COMPUTE val_out = val_in;
if (has_min) {
val_out = utils::max_override(val_out, val_min);
}
Expand Down
Loading
Loading