Skip to content

Commit 2d19e75

Browse files
committed
Update
[ghstack-poisoned]
1 parent 75f8970 commit 2d19e75

File tree

1 file changed

+22
-27
lines changed

1 file changed

+22
-27
lines changed

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -63,35 +63,16 @@ using op_call_result =
6363
std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;
6464

6565
#ifdef ET_USE_PYTORCH_HEADERS
66-
template <typename T>
67-
struct is_vectorized : public std::false_type {};
68-
69-
template <typename T>
70-
struct is_vectorized<at::vec::Vectorized<T>> : public std::true_type {};
71-
72-
// TODO: can_use_vectorized and can_use_vectorized_impl are a failed
73-
// attempt to use SFINAE to detect whether our generic lambda argument
74-
// with deduced return type would compile if it was passed
75-
// Vectorized<CTYPE_COMMON> instead of CTYPE_COMMON. SFINAE does not
76-
// work that way (see
77-
// e.g. https://stackoverflow.com/questions/53344484/hard-error-when-using-stdinvoke-result-t-with-a-generic-lambda,
78-
// https://stackoverflow.com/questions/31368601/how-to-detect-if-a-generic-lambda-is-uncompilable-in-c-14);
79-
// if we really want to do it then we need to at least require that
80-
// our lambdas actively participate in being SFINAE-friendly, as in
81-
// https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable.
82-
template <typename CTYPE_COMMON, typename Op, typename Enable=void, typename... Args>
83-
struct can_use_vectorized_impl : std::false_type {};
84-
template <typename CTYPE_COMMON, typename Op, typename... Args>
85-
struct can_use_vectorized_impl<CTYPE_COMMON, Op, typename std::void_t<decltype(std::declval<std::invoke_result_t<
86-
Op,
87-
ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>>().store(std::declval<CTYPE_COMMON*>()))>, Args...> : public std::true_type {};//std::bool_constant<is_vectorized<std::invoke_result_t<Op,ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>>::value> {};
88-
8966
// Can I call a function of type Op with sizeof...(Args) arguments of type
9067
// at::vec::Vectorized<CTYPE_COMMON>?
91-
// This is not possible in C++17 as the code is currently set up; see TODO above.
92-
template <typename CTYPE_COMMON, typename Op, typename...Args>
93-
struct can_use_vectorized : public can_use_vectorized_impl<CTYPE_COMMON, Op, void, Args...> {};
94-
68+
//
69+
// See [NOTE: Generic lambdas] below for requirements on Op.
70+
template <typename CTYPE_COMMON, typename Op, typename... Args>
71+
constexpr bool can_use_vectorized() {
72+
return std::is_invocable_v<
73+
Op,
74+
ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>;
75+
}
9576
#endif // ET_USE_PYTORCH_HEADERS
9677

9778
template <
@@ -349,6 +330,17 @@ inline void apply_unitensor_elementwise_fn(
349330
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
350331
}
351332

333+
/**
334+
* Useful for unary elementwise operators. For each element of the
335+
* input, call Op and write to the corresponding element of the
336+
* output. Tensor broadcasting is applied wherever it is required.
337+
*
338+
* [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
339+
* parameters; normal lambdas are fine), it must fulfill one of the
340+
* following conditions. Either:
341+
* 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMMON>, or
342+
* 2) It must be actively SFINAE-friendly, as per the C++17 examples in https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable .
343+
*/
352344
template <
353345
typename CTYPE_COMMON,
354346
const char* op_name,
@@ -390,6 +382,7 @@ inline void apply_bitensor_elementwise_fn(
390382
* Useful for bi-tensor elementwise operators. For each element of the inputs,
391383
* perform a computation and write to the corresponding element of the output.
392384
* Tensor broadcasting is applied wherever it is required.
385+
* See [NOTE: Generic lambdas] if you want to pass a generic lambda for compute_fun.
393386
*/
394387
template <
395388
typename CTYPE_COMMON,
@@ -456,6 +449,8 @@ inline void apply_tritensor_elementwise_fn(
456449
*
457450
* static constexpr const char op_name[] = "my_op";
458451
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
452+
*
453+
* See [NOTE: Generic lambdas] if you want to pass a generic lambda for compute_fun.
459454
*/
460455
template <
461456
typename CTYPE_COMMON,

0 commit comments

Comments
 (0)