@@ -20,34 +20,42 @@ inline constexpr bool should_include_kernel_dtype(
2020 const char *operator_name,
2121 executorch::aten::ScalarType scalar_type
2222) {
23- return ((executorch::aten::string_view (operator_name).compare (" add.out" ) == 0 )
24- && (scalar_type == executorch::aten::ScalarType::Float || scalar_type == executorch::aten::ScalarType::Int));
25- // || ((executorch::aten::string_view(operator_name).compare("mul.out") == 0)
23+ // return (((std::string_view(operator_name).compare("add.out") == 0))
2624// && (scalar_type == executorch::aten::ScalarType::Float));
27- // || ((executorch::aten::string_view(operator_name).compare("sub.out") == 0)
28- // && (true));
29- }
30-
31- // inline constexpr bool should_include_kernel_dtype(
32- // const char *operator_name,
33- // executorch::aten::ScalarType scalar_type
34- // ) {
3525// return ((executorch::aten::string_view(operator_name).compare("add.out") == 0)
3626// && (true))
3727// || ((executorch::aten::string_view(operator_name).compare("mm.out") == 0)
3828// && (true));
3929// }
30+
31+
32+ return ((std::string_view (operator_name).compare (" _native_batch_norm_legit_no_training.out" ) == 0 )
33+ && (scalar_type == executorch::aten::ScalarType::Float))
34+ || ((std::string_view (operator_name).compare (" add.out" ) == 0 )
35+ && (scalar_type == executorch::aten::ScalarType::Float))
36+ || ((std::string_view (operator_name).compare (" addmm.out" ) == 0 )
37+ && (scalar_type == executorch::aten::ScalarType::Float))
38+ || ((std::string_view (operator_name).compare (" clone.out" ) == 0 )
39+ && (scalar_type == executorch::aten::ScalarType::Float))
40+ || ((std::string_view (operator_name).compare (" convolution.out" ) == 0 )
41+ && (scalar_type == executorch::aten::ScalarType::Float))
42+ || ((std::string_view (operator_name).compare (" hardtanh.out" ) == 0 )
43+ && (scalar_type == executorch::aten::ScalarType::Float))
44+ // || ((std::string_view(operator_name).compare("hardtanh.out") == 0)
45+ // && (scalar_type == executorch::aten::ScalarType::Double))
46+ || ((std::string_view (operator_name).compare (" mean.out" ) == 0 )
47+ && (scalar_type == executorch::aten::ScalarType::Float))
48+ || ((std::string_view (operator_name).compare (" mean_dim.out" ) == 0 )
49+ && (scalar_type == executorch::aten::ScalarType::Float))
50+ || ((std::string_view (operator_name).compare (" permute_copy.out" ) == 0 )
51+ && (scalar_type == executorch::aten::ScalarType::Float));
52+ }
4053/*
4154inline constexpr bool should_include_kernel_dtype(
4255 const char* ,//operator_name,
4356 executorch::aten::ScalarType //scalar_type*
4457) {
4558 return true;
46-
47- // return ((executorch::aten::string_view(operator_name).compare("my_ops::mul3.out") == 0)
48- // && (true))
49- // || ((executorch::aten::string_view(operator_name).compare("my_ops::mul4.out") == 0)
50- // && (true));
5159}
5260*/
5361#endif
0 commit comments