-
Notifications
You must be signed in to change notification settings - Fork 683
RFC: Specialize for non-mixed-dtype in elementwise_util #9388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 65 commits
Commits
Show all changes
66 commits
Select commit
Hold shift + click to select a range
31a49e0
Update
swolchok 9fcd885
Update
swolchok 29d6de9
Update
swolchok 79b908c
Update
swolchok fd62a07
Update
swolchok 854c991
Update
swolchok def7ed4
Update
swolchok 40c1b1b
Update
swolchok 7c78357
Update
swolchok 7ba269a
Update
swolchok edd45fb
Update
swolchok b9c545f
Update
swolchok 3091007
Update
swolchok 4a00cac
Update
swolchok 21b81bf
Update
swolchok 4c4add0
Update
swolchok 8782a90
Update
swolchok 75f8970
Update
swolchok 2d19e75
Update
swolchok b61a8a2
Update
swolchok 91161bd
Update
swolchok 4add706
Update
swolchok 5348a92
Update
swolchok 001d72c
Update
swolchok e49080d
Update
swolchok 44ee51a
Update
swolchok f659627
Update
swolchok f1c5429
Update
swolchok b34f04f
Update
swolchok f934bc0
Update
swolchok 3a74f25
Update
swolchok 9a93839
Update
swolchok bb16a55
Update
swolchok 2242f1e
Update
swolchok 7f57a19
Update
swolchok 5d95c06
Update
swolchok 42623bb
Update
swolchok 4553283
Update
swolchok 39610ad
Update
swolchok b3120fa
Update
swolchok ff2c358
Update
swolchok 7086659
Update
swolchok e13de0e
Update
swolchok 943ab82
Update
swolchok f22d039
Update
swolchok 45ce46d
Update
swolchok 754dba4
Update
swolchok 34eb5d4
Update
swolchok ea9dc6f
Update
swolchok 7d7859e
Update
swolchok b98829d
Update
swolchok 3140910
Update
swolchok 946f2e0
Update
swolchok 7f2bbdb
Update
swolchok 960315e
Update
swolchok 9e42e93
Update
swolchok 96d258e
Update
swolchok e6f66ab
Update
swolchok de9d52f
Update
swolchok 20f3046
Update
swolchok 3aa266d
Update
swolchok 3c88a56
Update
swolchok 153735d
Update
swolchok cac4293
Update
swolchok 85451ea
Update
swolchok b0fc7f9
Update
swolchok File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,10 +53,44 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) { | |
namespace internal { | ||
template < | ||
typename CTYPE_COMPUTE, | ||
const char* op_name, | ||
typename CTYPE_OUT, | ||
typename Op, | ||
typename... Args> | ||
inline void apply_elementwise_fn( | ||
inline void dtype_specialized_elementwise_fn_impl( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& out, | ||
Args... inputs) { | ||
constexpr auto kNumInputs = sizeof...(inputs); | ||
ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMPUTE)) && ...)); | ||
|
||
::executorch::extension::parallel_for( | ||
0, | ||
out.numel(), | ||
::executorch::extension::internal::GRAIN_SIZE, | ||
[&](const auto begin, const auto end) { | ||
std::array<const CTYPE_COMPUTE*, kNumInputs> inputs_data_ptrs = { | ||
inputs.first->template const_data_ptr<CTYPE_COMPUTE>()...}; | ||
|
||
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>(); | ||
|
||
const auto range = | ||
BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...); | ||
auto begin_it = range.begin(); | ||
begin_it += begin; | ||
for (; (*begin_it)[0] < end; ++begin_it) { | ||
const auto& indexes = *begin_it; | ||
std::array<CTYPE_COMPUTE, kNumInputs> loaded_inputs; | ||
for (const auto idx : c10::irange(kNumInputs)) { | ||
loaded_inputs[idx] = inputs_data_ptrs[idx][indexes[idx + 1]]; | ||
} | ||
data_out[indexes[0]] = std::apply(compute_fun, loaded_inputs); | ||
} | ||
}); | ||
} | ||
|
||
template <typename CTYPE_COMPUTE, typename Op, typename... Args> | ||
inline bool validate_elementwise_fn_inputs( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& out, | ||
|
@@ -65,7 +99,6 @@ inline void apply_elementwise_fn( | |
static_assert( | ||
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> && | ||
...)); | ||
constexpr auto kNumInputs = sizeof...(inputs); | ||
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value; | ||
const auto check_input_dtype = [](auto input, auto compute_type) { | ||
return internal::check_tensor_dtype( | ||
|
@@ -75,7 +108,24 @@ inline void apply_elementwise_fn( | |
ctx, | ||
(check_input_dtype(inputs, compute_type) && ...) && | ||
internal::check_tensor_dtype(out, out_dtypes, compute_type), | ||
InvalidArgument, ); | ||
InvalidArgument, | ||
false); | ||
|
||
return true; | ||
} | ||
|
||
template < | ||
typename CTYPE_COMPUTE, | ||
const char* op_name, | ||
typename Op, | ||
typename... Args> | ||
inline void apply_elementwise_fn_generic_impl( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes, | ||
Args... inputs) { | ||
constexpr auto kNumInputs = sizeof...(inputs); | ||
|
||
struct InputInfo { | ||
load_to_compute_fn<CTYPE_COMPUTE> load_to_compute; | ||
|
@@ -120,6 +170,64 @@ inline void apply_elementwise_fn( | |
}); | ||
} | ||
|
||
template < | ||
typename CTYPE_COMPUTE, | ||
const char* op_name, | ||
typename Op, | ||
typename... Args> | ||
inline void apply_elementwise_fn_runtime_out_dtypes( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes, | ||
Args... inputs) { | ||
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>( | ||
compute_fun, ctx, out, out_dtypes, inputs...); | ||
if (!inputs_valid) { | ||
return; | ||
} | ||
|
||
apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>( | ||
compute_fun, ctx, out, out_dtypes, inputs...); | ||
} | ||
|
||
template < | ||
typename CTYPE_COMPUTE, | ||
const char* op_name, | ||
SupportedTensorDtypes out_dtypes, | ||
typename Op, | ||
typename... Args> | ||
inline void apply_elementwise_fn( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& out, | ||
Args... inputs) { | ||
const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>( | ||
compute_fun, ctx, out, out_dtypes, inputs...); | ||
if (!inputs_valid) { | ||
return; | ||
} | ||
|
||
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value; | ||
const bool all_inputs_compute_dtype = | ||
((inputs.first->scalar_type() == compute_type) && ...); | ||
|
||
constexpr ScalarType out_specialized_scalar_type = | ||
specialized_output_scalar_type<CTYPE_COMPUTE>(out_dtypes); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. here, |
||
if (all_inputs_compute_dtype && | ||
out.scalar_type() == out_specialized_scalar_type) { | ||
using CTYPE_OUT = | ||
typename ScalarTypeToCppType<out_specialized_scalar_type>::type; | ||
dtype_specialized_elementwise_fn_impl<CTYPE_COMPUTE, CTYPE_OUT>( | ||
compute_fun, ctx, out, inputs...); | ||
return; | ||
} | ||
|
||
apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>( | ||
compute_fun, ctx, out, out_dtypes, inputs...); | ||
} | ||
|
||
/// DEPRECATED: prefer the variant with out_dtypes in the template argument. | ||
template <typename CTYPE_COMPUTE, const char* op_name, typename Op> | ||
inline void apply_unitensor_elementwise_fn( | ||
const Op& compute_fun, | ||
|
@@ -128,32 +236,96 @@ inline void apply_unitensor_elementwise_fn( | |
SupportedTensorDtypes a_dtypes, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>( | ||
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>( | ||
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes)); | ||
} | ||
|
||
template < | ||
typename CTYPE_COMPUTE, | ||
const char* op_name, | ||
SupportedTensorDtypes out_dtypes, | ||
typename Op> | ||
inline void apply_unitensor_elementwise_fn( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& a, | ||
SupportedTensorDtypes a_dtypes, | ||
const Tensor& out) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>( | ||
compute_fun, ctx, out, std::make_pair(&a, a_dtypes)); | ||
} | ||
|
||
/** | ||
* DEPRECATED: prefer the variant with out_dtypes in the template argument list. | ||
*/ | ||
template <typename CTYPE_COMPUTE, const char* op_name, typename Op> | ||
inline void apply_bitensor_elementwise_fn( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& a, | ||
SupportedTensorDtypes a_dtypes, | ||
const Tensor& b, | ||
SupportedTensorDtypes b_dtypes, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes) { | ||
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>( | ||
compute_fun, | ||
ctx, | ||
out, | ||
out_dtypes, | ||
std::make_pair(&a, a_dtypes), | ||
std::make_pair(&b, b_dtypes)); | ||
} | ||
|
||
/** | ||
* Useful for bi-tensor elementwise operators. For each element of the inputs, | ||
* perform a computation and write to the corresponding element of the output. | ||
* Tensor broadcasting is applied wherever it is required. | ||
*/ | ||
template <typename CTYPE_COMPUTE, const char* op_name, typename Op> | ||
template < | ||
typename CTYPE_COMPUTE, | ||
const char* op_name, | ||
SupportedTensorDtypes out_dtypes, | ||
typename Op> | ||
inline void apply_bitensor_elementwise_fn( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& a, | ||
SupportedTensorDtypes a_dtypes, | ||
const Tensor& b, | ||
SupportedTensorDtypes b_dtypes, | ||
const Tensor& out) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>( | ||
compute_fun, | ||
ctx, | ||
out, | ||
std::make_pair(&a, a_dtypes), | ||
std::make_pair(&b, b_dtypes)); | ||
} | ||
|
||
/** | ||
* DEPRECATED: prefer the variant with out_dtypes in the template argument list. | ||
*/ | ||
template <typename CTYPE_COMPUTE, const char* op_name, typename Op> | ||
inline void apply_tritensor_elementwise_fn( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
const Tensor& a, | ||
SupportedTensorDtypes a_dtypes, | ||
const Tensor& b, | ||
SupportedTensorDtypes b_dtypes, | ||
const Tensor& c, | ||
SupportedTensorDtypes c_dtypes, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>( | ||
internal::apply_elementwise_fn_runtime_out_dtypes<CTYPE_COMPUTE, op_name>( | ||
compute_fun, | ||
ctx, | ||
out, | ||
out_dtypes, | ||
std::make_pair(&a, a_dtypes), | ||
std::make_pair(&b, b_dtypes)); | ||
std::make_pair(&b, b_dtypes), | ||
std::make_pair(&c, c_dtypes)); | ||
} | ||
|
||
/** | ||
|
@@ -176,7 +348,11 @@ inline void apply_bitensor_elementwise_fn( | |
* static constexpr const char op_name[] = "my_op"; | ||
* apply_ternary_elementwise_fn<CTYPE_COMPUTE, op_name>. | ||
*/ | ||
template <typename CTYPE_COMPUTE, const char* op_name, typename Op> | ||
template < | ||
typename CTYPE_COMPUTE, | ||
const char* op_name, | ||
SupportedTensorDtypes out_dtypes, | ||
typename Op> | ||
inline void apply_tritensor_elementwise_fn( | ||
const Op& compute_fun, | ||
KernelRuntimeContext& ctx, | ||
|
@@ -186,13 +362,11 @@ inline void apply_tritensor_elementwise_fn( | |
SupportedTensorDtypes b_dtypes, | ||
const Tensor& c, | ||
SupportedTensorDtypes c_dtypes, | ||
const Tensor& out, | ||
SupportedTensorDtypes out_dtypes) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name>( | ||
const Tensor& out) { | ||
internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>( | ||
compute_fun, | ||
ctx, | ||
out, | ||
out_dtypes, | ||
std::make_pair(&a, a_dtypes), | ||
std::make_pair(&b, b_dtypes), | ||
std::make_pair(&c, c_dtypes)); | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be `CTYPE_COMPUTE`, right?