85 changes: 3 additions & 82 deletions kernels/portable/cpu/op_mul.cpp
@@ -6,10 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/kernels/portable/cpu/op_mul_impl.h>

namespace torch {
namespace executor {
@@ -20,91 +17,15 @@ Tensor& mul_out(
const Tensor& a,
const Tensor& b,
Tensor& out) {
// Common Dtype
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

// Check Common Dtype
ET_KERNEL_CHECK(
ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

// Resize
ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);

// Compute Dtype
ScalarType compute_type = utils::get_compute_type(common_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "mul.out";

ET_KERNEL_CHECK(
ctx,
(executorch::runtime::isRealType(compute_type) ||
compute_type == ScalarType::Bool),
InvalidArgument,
out);

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a * val_b;
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
});

return out;
return mul_out_impl(ctx, a, b, out);
}

Tensor& mul_scalar_out(
KernelRuntimeContext& ctx,
const Tensor& a,
const Scalar& b,
Tensor& out) {
// Common Dtype
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

// Check Common Dtype
ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

// Resize
ET_KERNEL_CHECK(
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

// Compute Dtype
ScalarType compute_type = utils::get_compute_type(common_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "mul.Scalar_out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
});

return out;
return mul_scalar_out_impl(ctx, a, b, out);
}

} // namespace native
114 changes: 114 additions & 0 deletions kernels/portable/cpu/op_mul_impl.cpp
@@ -0,0 +1,114 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

#include <executorch/kernels/portable/cpu/op_mul_impl.h>

namespace torch {
namespace executor {
namespace native {

Tensor& mul_out_impl(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out) {
// Common Dtype
ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

// Check Common Dtype
ET_KERNEL_CHECK(
ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

// Resize
ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);

// Compute Dtype
ScalarType compute_type = utils::get_compute_type(common_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "mul.out";

ET_KERNEL_CHECK(
ctx,
(executorch::runtime::isRealType(compute_type) ||
compute_type == ScalarType::Bool),
InvalidArgument,
out);

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a * val_b;
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
});

return out;
}

Tensor& mul_scalar_out_impl(
KernelRuntimeContext& ctx,
const Tensor& a,
const Scalar& b,
Tensor& out) {
// Common Dtype
ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

// Check Common Dtype
ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

// Resize
ET_KERNEL_CHECK(
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

// Compute Dtype
ScalarType compute_type = utils::get_compute_type(common_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "mul.Scalar_out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
});

return out;
}

} // namespace native
} // namespace executor
} // namespace torch
29 changes: 29 additions & 0 deletions kernels/portable/cpu/op_mul_impl.h
@@ -0,0 +1,29 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& mul_out_impl(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out);

Tensor& mul_scalar_out_impl(
KernelRuntimeContext& ctx,
const Tensor& a,
const Scalar& b,
Tensor& out);

} // namespace native
} // namespace executor
} // namespace torch
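
Note: the reason the implementation is split into op_mul_impl.h / op_mul_impl.cpp is so that other kernel libraries (for example the optimized ops, per the visibility list in targets.bzl below) can reuse the portable code path. A minimal sketch of what a caller might look like — the function name opt_mul_out_example and its fast-path comment are hypothetical illustrations, not part of this PR:

#include <executorch/kernels/portable/cpu/op_mul_impl.h>

namespace torch {
namespace executor {
namespace native {

// Hypothetical optimized kernel: attempt a specialized fast path first,
// otherwise delegate to the shared portable implementation.
Tensor& opt_mul_out_example(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // ... a vectorized fast path for contiguous, same-dtype inputs could go here ...
  // Fall back to the shared portable implementation declared in op_mul_impl.h.
  return mul_out_impl(ctx, a, b, out);
}

} // namespace native
} // namespace executor
} // namespace torch
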
21 changes: 21 additions & 0 deletions kernels/portable/cpu/targets.bzl
@@ -111,6 +111,26 @@ def define_common_targets():
],
)

runtime.cxx_library(
name = "op_mul_impl",
srcs = ["op_mul_impl.cpp"],
exported_headers = ["op_mul_impl.h"],
visibility = [
"//executorch/kernels/portable/cpu/...",
"//executorch/kernels/optimized/cpu/...",
"//executorch/kernels/portable/test/...",
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
"//executorch/kernels/portable/cpu/util:broadcast_util",
"//executorch/kernels/portable/cpu/util:dtype_util",
"//executorch/kernels/portable/cpu/util:elementwise_util",
"//executorch/kernels/portable/cpu/util:math_util",
"//executorch/kernels/portable/cpu:scalar_utils",
"//executorch/runtime/kernel:kernel_includes",
],
)

# The following will not participate in dtype selective build because
# they are refactored such to be used in optimized op implementations as well
# and we have not enabled selective build for optimized ops.
@@ -123,6 +143,7 @@
name = "all_impl_deps",
deps = [
"//executorch/kernels/portable/cpu:op_div_impl",
"//executorch/kernels/portable/cpu:op_mul_impl",
],
visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
)
@@ -842,10 +842,7 @@ ATEN_OPS = (
op_target(
name = "op_mul",
deps = [
"//executorch/kernels/portable/cpu/util:broadcast_util",
"//executorch/kernels/portable/cpu/util:dtype_util",
"//executorch/kernels/portable/cpu/util:elementwise_util",
":scalar_utils",
"//executorch/kernels/portable/cpu:op_mul_impl",
],
),
op_target(
@@ -1271,4 +1268,4 @@ def portable_source_list():

def portable_header_list():
"""All the header file names from //executorch/kernels/portable/cpu/"""
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h"]
return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h"]