Skip to content

Commit c3085d5

Browse files
Update on "[ET][Portable][Build Size] REALHBF16 binary ops: maximum, minimum, mul"
- mul: 1.69 M -> 15 K
- maximum: 353 K -> 11 K
- minimum: 353 K -> 11 K

Differential Revision: [D63909726](https://our.internmc.facebook.com/intern/diff/D63909726/)

[ghstack-poisoned]
2 parents: 85fbe0d + 4b833d3; commit: c3085d5

File tree

2 files changed: +12 / -6 lines changed

kernels/portable/cpu/op_mul.cpp

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,12 @@ Tensor& mul_out(
4343

4444
static constexpr const char op_name[] = "mul.out";
4545

46+
ET_KERNEL_CHECK(
47+
ctx,
48+
(executorch::runtime::isRealType(compute_type) || compute_type == ScalarType::Bool),
49+
InvalidArgument,
50+
out);
51+
4652
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
4753
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
4854
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
@@ -87,9 +93,7 @@ Tensor& mul_scalar_out(
8793
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
8894
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
8995
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
90-
[val_b](const CTYPE_COMPUTE val_a) {
91-
return val_a * val_b;
92-
},
96+
[val_b](const CTYPE_COMPUTE val_a) {return val_a * val_b;},
9397
ctx,
9498
a,
9599
utils::SupportedTensorDtypes::REALHBBF16,

kernels/portable/test/op_mul_test.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -49,12 +49,14 @@ TEST_F(OpMulOutKernelTest, UnhandledDtypeDies) {
4949
std::vector<exec_aten::qint8> b_data(a_data);
5050
std::vector<exec_aten::qint8> out_data(a_data);
5151

52+
std::vector<exec_aten::DimOrderType> dim_order = {0, 1};
53+
5254
auto a_impl = torch::executor::TensorImpl(
53-
ScalarType::QInt8, 2, sizes.data(), a_data.data());
55+
ScalarType::QInt8, 2, sizes.data(), a_data.data(), dim_order.data());
5456
auto b_impl = torch::executor::TensorImpl(
55-
ScalarType::QInt8, 2, sizes.data(), b_data.data());
57+
ScalarType::QInt8, 2, sizes.data(), b_data.data(), dim_order.data());
5658
auto out_impl = torch::executor::TensorImpl(
57-
ScalarType::QInt8, 2, sizes.data(), out_data.data());
59+
ScalarType::QInt8, 2, sizes.data(), out_data.data(), dim_order.data());
5860

5961
// Two input tensors.
6062
Tensor a(&a_impl);

0 commit comments

Comments
 (0)