@@ -220,30 +220,29 @@ inline void dtype_specialized_elementwise_fn_impl(
220220 });
221221}
222222
223- template <typename CTYPE_COMPUTE, typename Op, typename ... Args>
224- inline bool validate_elementwise_fn_inputs (
225- const Op& compute_fun,
223+ bool validate_elementwise_fn_inputs (
226224 KernelRuntimeContext& ctx,
227225 const Tensor& out,
228226 SupportedTensorDtypes out_dtypes,
229- Args... inputs) {
230- static_assert (
231- (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
232- ...));
233- constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
234- const auto check_input_dtype = [](auto input, auto compute_type) {
235- return internal::check_tensor_dtype (
236- *input.first , input.second , compute_type);
237- };
238- ET_KERNEL_CHECK (
239- ctx,
240- (check_input_dtype (inputs, compute_type) && ...) &&
241- internal::check_tensor_dtype (out, out_dtypes, compute_type),
242- InvalidArgument,
243- false );
227+ ScalarType compute_type,
228+ std::pair<const Tensor*, SupportedTensorDtypes> input);
244229
245- return true ;
246- }
230+ bool validate_elementwise_fn_inputs (
231+ KernelRuntimeContext& ctx,
232+ const Tensor& out,
233+ SupportedTensorDtypes out_dtypes,
234+ ScalarType compute_type,
235+ std::pair<const Tensor*, SupportedTensorDtypes> input0,
236+ std::pair<const Tensor*, SupportedTensorDtypes> input1);
237+
238+ bool validate_elementwise_fn_inputs (
239+ KernelRuntimeContext& ctx,
240+ const Tensor& out,
241+ SupportedTensorDtypes out_dtypes,
242+ ScalarType compute_type,
243+ std::pair<const Tensor*, SupportedTensorDtypes> input0,
244+ std::pair<const Tensor*, SupportedTensorDtypes> input1,
245+ std::pair<const Tensor*, SupportedTensorDtypes> input2);
247246
248247template <
249248 typename CTYPE_COMPUTE,
@@ -314,8 +313,9 @@ inline void apply_elementwise_fn_runtime_out_dtypes(
314313 const Tensor& out,
315314 SupportedTensorDtypes out_dtypes,
316315 Args... inputs) {
317- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
318- compute_fun, ctx, out, out_dtypes, inputs...);
316+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
317+ const bool inputs_valid = validate_elementwise_fn_inputs (
318+ ctx, out, out_dtypes, compute_type, inputs...);
319319 if (!inputs_valid) {
320320 return ;
321321 }
@@ -339,13 +339,13 @@ inline void apply_elementwise_fn(
339339 KernelRuntimeContext& ctx,
340340 const Tensor& out,
341341 Args... inputs) {
342- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
343- compute_fun, ctx, out, out_dtypes, inputs...);
342+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
343+ const bool inputs_valid = validate_elementwise_fn_inputs (
344+ ctx, out, out_dtypes, compute_type, inputs...);
344345 if (!inputs_valid) {
345346 return ;
346347 }
347348
348- constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
349349 if constexpr (should_include_kernel_dtype (op_name, compute_type)) {
350350 const bool all_inputs_compute_dtype =
351351 ((inputs.first ->scalar_type () == compute_type) && ...);
0 commit comments