Skip to content

Commit ae73737

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Add upsample_bilinar2d_aa correctness test and enable ATen mode
Differential Revision: D81720498
1 parent f4ec01a commit ae73737

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

kernels/test/op_upsample_bilinear2d_aa_test.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,34 @@ TEST_F(OpUpsampleBilinear2dAAOutTest, TestPrecisionConsistency) {
625625
EXPECT_EQ(out1_data[i], out2_data[i]);
626626
}
627627
}
628+
629+
TEST_F(OpUpsampleBilinear2dAAOutTest, SmokeTestCorrectness) {
630+
TensorFactory<ScalarType::Float> tf;
631+
632+
Tensor input = tf.make(
633+
{1, 1, 8, 1},
634+
{-98.5, 49.875, 17.125, -46.5, 10.625, -95.875, -3.875, -4.625});
635+
636+
int64_t output_size_data[2] = {3, 8};
637+
ArrayRef<int64_t> output_size(output_size_data, 2);
638+
639+
Tensor out = tf.zeros({1, 1, 3, 8});
640+
641+
// clang-format off
642+
Tensor expected = tf.make({1, 1, 3, 8}, {
643+
-8.4408, -8.4408, -8.4408, -8.4408, -8.4408, -8.4408, -8.4408, -8.4408,
644+
-23.1339, -23.1339, -23.1339, -23.1339, -23.1339, -23.1339, -23.1339, -23.1339,
645+
-24.7368, -24.7368, -24.7368, -24.7368, -24.7368, -24.7368, -24.7368, -24.7368
646+
});
647+
// clang-format on
648+
649+
op_upsample_bilinear2d_aa_out(
650+
input,
651+
output_size,
652+
/*align_corners=*/false,
653+
std::nullopt,
654+
std::nullopt,
655+
out);
656+
657+
EXPECT_TENSOR_CLOSE(out, expected);
658+
}

kernels/test/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def define_common_targets():
335335
_common_op_test("op_unfold_copy_test", ["aten", "portable"])
336336
_common_op_test("op_unsqueeze_copy_test", ["aten", "portable"])
337337
_common_op_test("op_upsample_bilinear2d_test", ["aten", "portable"])
338-
_common_op_test("op_upsample_bilinear2d_aa_test", ["portable"])
338+
_common_op_test("op_upsample_bilinear2d_aa_test", ["aten", "portable"])
339339
_common_op_test("op_upsample_nearest2d_test", ["aten", "portable"])
340340
_common_op_test("op_var_test", ["aten", "portable"])
341341
_common_op_test("op_view_as_real_copy_test", ["aten", "portable"])

0 commit comments

Comments
 (0)