@@ -328,6 +328,9 @@ TEST_F(OpVarOutTest, InvalidDTypeDies) {
328328}
329329
330330TEST_F (OpVarOutTest, AllFloatInputFloatOutputPasses) {
331+ if (torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
332+ GTEST_SKIP () << " ATen supports fewer dtypes" ;
333+ }
331334 // Use a two layer switch to hanldle each possible data pair
332335#define TEST_KERNEL (INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE ) \
333336 test_var_out_dtype<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
@@ -340,6 +343,22 @@ TEST_F(OpVarOutTest, AllFloatInputFloatOutputPasses) {
340343#undef TEST_KERNEL
341344}
342345
346+ TEST_F (OpVarOutTest, AllFloatInputFloatOutputPasses_Aten) {
347+ if (!torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
348+ GTEST_SKIP () << " ATen-specific variant of test case" ;
349+ }
350+ // Use a two layer switch to hanldle each possible data pair
351+ #define TEST_KERNEL (INPUT_CTYPE, INPUT_DTYPE, OUTPUT_CTYPE, OUTPUT_DTYPE ) \
352+ test_var_out_dtype<ScalarType::INPUT_DTYPE, ScalarType::OUTPUT_DTYPE>();
353+
354+ #define TEST_ENTRY (INPUT_CTYPE, INPUT_DTYPE ) \
355+ ET_FORALL_FLOAT_TYPES_WITH2 (INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
356+
357+ ET_FORALL_FLOAT_TYPES (TEST_ENTRY);
358+ #undef TEST_ENTRY
359+ #undef TEST_KERNEL
360+ }
361+
343362TEST_F (OpVarOutTest, InfinityAndNANTest) {
344363 TensorFactory<ScalarType::Float> tf_float;
345364 // clang-format off
0 commit comments