@@ -63,35 +63,16 @@ using op_call_result =
63
63
std::invoke_result_t <Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;
64
64
65
65
#ifdef ET_USE_PYTORCH_HEADERS
66
- template <typename T>
67
- struct is_vectorized : public std ::false_type {};
68
-
69
- template <typename T>
70
- struct is_vectorized <at::vec::Vectorized<T>> : public std::true_type {};
71
-
72
- // TODO: can_use_vectorized and can_use_vectorized_impl are a failed
73
- // attempt to use SFINAE to detect whether our generic lambda argument
74
- // with deduced return type would compile if it was passed
75
- // Vectorized<CTYPE_COMMON> instead of CTYPE_COMMON. SFINAE does not
76
- // work that way (see
77
- // e.g. https://stackoverflow.com/questions/53344484/hard-error-when-using-stdinvoke-result-t-with-a-generic-lambda,
78
- // https://stackoverflow.com/questions/31368601/how-to-detect-if-a-generic-lambda-is-uncompilable-in-c-14);
79
- // if we really want to do it then we need to at least require that
80
- // our lambdas actively participate in being SFINAE-friendly, as in
81
- // https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable.
82
- template <typename CTYPE_COMMON, typename Op, typename Enable=void , typename ... Args>
83
- struct can_use_vectorized_impl : std::false_type {};
84
- template <typename CTYPE_COMMON, typename Op, typename ... Args>
85
- struct can_use_vectorized_impl <CTYPE_COMMON, Op, typename std::void_t <decltype (std::declval<std::invoke_result_t <
86
- Op,
87
- ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>>().store(std::declval<CTYPE_COMMON*>()))>, Args...> : public std::true_type {};// std::bool_constant<is_vectorized<std::invoke_result_t<Op,ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>>::value> {};
88
-
89
66
// Can I call a function of type Op with sizeof...(Args) arguments of type
90
67
// at::vec::Vectorized<CTYPE_COMMON>?
91
- // This is not possible in C++17 as the code is currently set up; see TODO above.
92
- template <typename CTYPE_COMMON, typename Op, typename ...Args>
93
- struct can_use_vectorized : public can_use_vectorized_impl <CTYPE_COMMON, Op, void , Args...> {};
94
-
68
+ //
69
+ // See [NOTE: Generic lambdas] below for requirements on Op.
70
+ template <typename CTYPE_COMMON, typename Op, typename ... Args>
71
+ constexpr bool can_use_vectorized () {
72
+ return std::is_invocable_v<
73
+ Op,
74
+ ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>;
75
+ }
95
76
#endif // ET_USE_PYTORCH_HEADERS
96
77
97
78
template <
@@ -349,6 +330,17 @@ inline void apply_unitensor_elementwise_fn(
349
330
compute_fun, ctx, out, out_dtypes, std::make_pair (&a, a_dtypes));
350
331
}
351
332
333
+ /* *
334
+ * Useful for unary elementwise operators. For each element of the
335
+ * input, call Op and write to the corresponding element of the
336
+ * output. Tensor broadcasting is applied wherever it is required.
337
+ *
338
+ * [NOTE: Generic lambdas]: If Op is a *generic* lambda (i.e., one with `auto`
339
+ * parameters; normal lambdas are fine), it must fulfill one of the
340
+ * following conditions. Either:
341
+ * 1) It must in fact compile when passed at::vec::Vectorized<CTYPE_COMMON>, or
342
+ * 2) It must be actively SFINAE-friendly, as per the C++17 examples in https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable .
343
+ */
352
344
template <
353
345
typename CTYPE_COMMON,
354
346
const char * op_name,
@@ -390,6 +382,7 @@ inline void apply_bitensor_elementwise_fn(
390
382
* Useful for bi-tensor elementwise operators. For each element of the inputs,
391
383
* perform a computation and write to the corresponding element of the output.
392
384
* Tensor broadcasting is applied wherever it is required.
385
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for compute_fun.
393
386
*/
394
387
template <
395
388
typename CTYPE_COMMON,
@@ -456,6 +449,8 @@ inline void apply_tritensor_elementwise_fn(
456
449
*
457
450
* static constexpr const char op_name[] = "my_op";
458
451
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
452
+ *
453
+ * See [NOTE: Generic lambdas] if you want to pass a generic lambda for compute_fun.
459
454
*/
460
455
template <
461
456
typename CTYPE_COMMON,
0 commit comments