@@ -51,6 +51,13 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
5151}
5252
5353namespace internal {
54+ template <typename Ignore, typename T>
55+ using ignore_first_yield_second = T;
56+
57+ template <typename CTYPE_COMMON, typename Op, typename ... Args>
58+ using op_call_result =
59+ std::invoke_result_t <Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;
60+
5461template <
5562 typename CTYPE_COMMON,
5663 const char * op_name,
@@ -89,9 +96,16 @@ inline void apply_elementwise_fn(
8996 inputs.first ->element_size (),
9097 })...};
9198
92- const auto store_common_to_out =
93- internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
94- out, out_dtypes);
99+ // NOTE: the result of compute_fun is not necessarily CTYPE_COMMON!
100+ // For example, consider the possibility that compute_fun is a
101+ // trigonometric function like acos, the common input type is bool,
102+ // and the output type is float -- we would truncate acos(0) ~= 1.67
103+ // to just 1. Conveniently, it costs us nothing at runtime to handle
104+ // this correctly.
105+ const auto store_compute_result_to_out =
106+ internal::get_store_common_to_tensor_fn<
107+ op_call_result<CTYPE_COMMON, Op, Args...>,
108+ op_name>(out, out_dtypes);
95109 char * const data_out = reinterpret_cast <char *>(out.mutable_data_ptr ());
96110 const auto out_element_size = out.element_size ();
97111
@@ -114,7 +128,8 @@ inline void apply_elementwise_fn(
114128 .data_ptr [indexes[idx + 1 ] * input_info.element_size ]);
115129 }
116130 auto result = std::apply (compute_fun, loaded_inputs);
117- store_common_to_out (result, &data_out[indexes[0 ] * out_element_size]);
131+ store_compute_result_to_out (
132+ result, &data_out[indexes[0 ] * out_element_size]);
118133 }
119134 });
120135}
0 commit comments