Skip to content

Commit 0c9b35c

Browse files
committed
[Executorch][optimized] Fix op_div impl to use portable for fallback path
Pull Request resolved: pytorch/executorch#6714. Earlier we had just copy-pasted from the portable impl. This diff refactors the portable implementation to make it usable from the optimized lib. As a result, we get all the size-reduction benefits from the build and size optimizations landed in portable. ghstack-source-id: 252584490 @exported-using-ghexport Differential Revision: [D65606665](https://our.internmc.facebook.com/intern/diff/D65606665/)
1 parent 2bc8bfb commit 0c9b35c

File tree

3 files changed

+5
-80
lines changed

3 files changed

+5
-80
lines changed

kernels/optimized/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ message("Generated files ${gen_command_sources}")
6161
list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
6262
add_library(optimized_kernels ${_optimized_kernels__srcs})
6363
target_link_libraries(
64-
optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
64+
optimized_kernels PRIVATE executorch_core portable_kernels cpublas extension_threadpool
6565
)
6666
target_compile_options(optimized_kernels PUBLIC ${_common_compile_options})
6767
# Build a library for _optimized_kernels_srcs

kernels/optimized/cpu/op_div.cpp

Lines changed: 3 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -9,35 +9,13 @@
99
#include <executorch/kernels/optimized/cpu/binary_ops.h>
1010
#include <executorch/kernels/optimized/vec/functional.h>
1111
#include <executorch/kernels/optimized/vec/vec.h>
12-
#include <executorch/kernels/portable/cpu/scalar_utils.h>
13-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
14-
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <executorch/kernels/portable/cpu/op_div_impl.h>
1513
#include <executorch/runtime/platform/assert.h>
1614

1715
namespace torch {
1816
namespace executor {
1917
namespace native {
2018

21-
namespace {
22-
23-
ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
24-
ET_CHECK(
25-
!isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
26-
ET_CHECK(
27-
!isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));
28-
29-
if (isFloatingType(a_type) && isFloatingType(b_type)) {
30-
return promoteTypes(a_type, b_type);
31-
} else if (isFloatingType(a_type)) {
32-
return a_type;
33-
} else if (isFloatingType(b_type)) {
34-
return b_type;
35-
}
36-
return ScalarType::Float;
37-
}
38-
39-
} // namespace
40-
4119
Tensor& opt_div_out(
4220
KernelRuntimeContext& ctx,
4321
const Tensor& a,
@@ -163,34 +141,7 @@ Tensor& opt_div_out(
163141
}
164142
});
165143
} else {
166-
ScalarType common_type = get_compute_type(a_type, b_type);
167-
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
168-
169-
ET_KERNEL_CHECK(
170-
ctx,
171-
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
172-
InvalidArgument,
173-
out);
174-
175-
ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
176-
ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
177-
ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
178-
ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
179-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
180-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
181-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
182-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
183-
CTYPE_IN value = a_casted / b_casted;
184-
185-
return static_cast<CTYPE_OUT>(value);
186-
},
187-
a,
188-
b,
189-
out);
190-
});
191-
});
192-
});
193-
});
144+
div_out_impl(ctx, a, b, out);
194145
}
195146

196147
return out;
@@ -232,32 +183,7 @@ Tensor& opt_div_scalar_out(
232183
});
233184
});
234185
} else {
235-
ET_SWITCH_REAL_TYPES_AND(
236-
Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
237-
ET_SWITCH_REAL_TYPES_AND(
238-
Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
239-
ET_SWITCH_REAL_TYPES(
240-
common_type, ctx, "div.Scalar_out", CTYPE_IN, [&]() {
241-
ET_SWITCH_REAL_TYPES(
242-
out_type, ctx, "div.Scalar_out", CTYPE_OUT, [&]() {
243-
CTYPE_B b_val;
244-
ET_EXTRACT_SCALAR(b, b_val);
245-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
246-
CTYPE_IN inv_b_casted = CTYPE_IN(1) / b_casted;
247-
248-
const size_t n = a.numel();
249-
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
250-
CTYPE_OUT* out_data =
251-
out.mutable_data_ptr<CTYPE_OUT>();
252-
for (auto i = 0; i < n; ++i) {
253-
out_data[i] = static_cast<CTYPE_OUT>(
254-
static_cast<CTYPE_IN>(a_data[i]) *
255-
inv_b_casted);
256-
}
257-
});
258-
});
259-
});
260-
});
186+
div_scalar_out_impl(ctx, a, b, out);
261187
}
262188

263189
return out;

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ _OPTIMIZED_ATEN_OPS = (
2020
name = "op_div",
2121
deps = [
2222
":binary_ops",
23-
"//executorch/kernels/portable/cpu:scalar_utils",
24-
"//executorch/kernels/portable/cpu/util:broadcast_util",
23+
"//executorch/kernels/portable/cpu:op_div_impl",
2524
],
2625
),
2726
op_target(name = "op_exp"),

0 commit comments

Comments (0)