@@ -220,30 +220,29 @@ inline void dtype_specialized_elementwise_fn_impl(
220
220
});
221
221
}
222
222
223
- template <typename CTYPE_COMPUTE, typename Op, typename ... Args>
224
- inline bool validate_elementwise_fn_inputs (
225
- const Op& compute_fun,
223
+ bool validate_elementwise_fn_inputs (
226
224
KernelRuntimeContext& ctx,
227
225
const Tensor& out,
228
226
SupportedTensorDtypes out_dtypes,
229
- Args... inputs) {
230
- static_assert (
231
- (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
232
- ...));
233
- constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
234
- const auto check_input_dtype = [](auto input, auto compute_type) {
235
- return internal::check_tensor_dtype (
236
- *input.first , input.second , compute_type);
237
- };
238
- ET_KERNEL_CHECK (
239
- ctx,
240
- (check_input_dtype (inputs, compute_type) && ...) &&
241
- internal::check_tensor_dtype (out, out_dtypes, compute_type),
242
- InvalidArgument,
243
- false );
227
+ ScalarType compute_type,
228
+ std::pair<const Tensor*, SupportedTensorDtypes> input);
244
229
245
- return true ;
246
- }
230
+ bool validate_elementwise_fn_inputs (
231
+ KernelRuntimeContext& ctx,
232
+ const Tensor& out,
233
+ SupportedTensorDtypes out_dtypes,
234
+ ScalarType compute_type,
235
+ std::pair<const Tensor*, SupportedTensorDtypes> input0,
236
+ std::pair<const Tensor*, SupportedTensorDtypes> input1);
237
+
238
+ bool validate_elementwise_fn_inputs (
239
+ KernelRuntimeContext& ctx,
240
+ const Tensor& out,
241
+ SupportedTensorDtypes out_dtypes,
242
+ ScalarType compute_type,
243
+ std::pair<const Tensor*, SupportedTensorDtypes> input0,
244
+ std::pair<const Tensor*, SupportedTensorDtypes> input1,
245
+ std::pair<const Tensor*, SupportedTensorDtypes> input2);
247
246
248
247
template <
249
248
typename CTYPE_COMPUTE,
@@ -314,8 +313,9 @@ inline void apply_elementwise_fn_runtime_out_dtypes(
314
313
const Tensor& out,
315
314
SupportedTensorDtypes out_dtypes,
316
315
Args... inputs) {
317
- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
318
- compute_fun, ctx, out, out_dtypes, inputs...);
316
+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
317
+ const bool inputs_valid = validate_elementwise_fn_inputs (
318
+ ctx, out, out_dtypes, compute_type, inputs...);
319
319
if (!inputs_valid) {
320
320
return ;
321
321
}
@@ -339,13 +339,13 @@ inline void apply_elementwise_fn(
339
339
KernelRuntimeContext& ctx,
340
340
const Tensor& out,
341
341
Args... inputs) {
342
- const bool inputs_valid = validate_elementwise_fn_inputs<CTYPE_COMPUTE>(
343
- compute_fun, ctx, out, out_dtypes, inputs...);
342
+ constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
343
+ const bool inputs_valid = validate_elementwise_fn_inputs (
344
+ ctx, out, out_dtypes, compute_type, inputs...);
344
345
if (!inputs_valid) {
345
346
return ;
346
347
}
347
348
348
- constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
349
349
if constexpr (should_include_kernel_dtype (op_name, compute_type)) {
350
350
const bool all_inputs_compute_dtype =
351
351
((inputs.first ->scalar_type () == compute_type) && ...);
0 commit comments