Skip to content

Commit e86b6b2

Browse files
Alvaro-Kothepytorchmergebot
authored andcommitted
Add tests to check pretty print when padding is a string in C++ API (pytorch#153126)
Currently there are no tests to verify the behaviour of pretty print when padding is `torch::kSame` or `torch::kValid`. This PR just adds this tests to check for future regressions. Pull Request resolved: pytorch#153126 Approved by: https://github.com/Skylion007
1 parent d36261d commit e86b6b2

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

test/cpp/api/modules.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4591,6 +4591,15 @@ TEST_F(ModulesTest, PrettyPrintConv) {
45914591
ASSERT_EQ(
45924592
c10::str(Conv1d(3, 4, 5)),
45934593
"torch::nn::Conv1d(3, 4, kernel_size=5, stride=1)");
4594+
{
4595+
auto options = Conv1dOptions(3, 4, 5);
4596+
ASSERT_EQ(
4597+
c10::str(Conv1d(options.padding(torch::kSame))),
4598+
"torch::nn::Conv1d(3, 4, kernel_size=5, stride=1, padding='same')");
4599+
ASSERT_EQ(
4600+
c10::str(Conv1d(options.padding(torch::kValid))),
4601+
"torch::nn::Conv1d(3, 4, kernel_size=5, stride=1, padding='valid')");
4602+
}
45944603

45954604
ASSERT_EQ(
45964605
c10::str(Conv2d(3, 4, 5)),
@@ -4605,6 +4614,15 @@ TEST_F(ModulesTest, PrettyPrintConv) {
46054614
c10::str(Conv2d(options)),
46064615
"torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 2])");
46074616
}
4617+
{
4618+
auto options = Conv2dOptions(3, 4, std::vector<int64_t>{5, 6});
4619+
ASSERT_EQ(
4620+
c10::str(Conv2d(options.padding(torch::kSame))),
4621+
"torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 1], padding='same')");
4622+
ASSERT_EQ(
4623+
c10::str(Conv2d(options.padding(torch::kValid))),
4624+
"torch::nn::Conv2d(3, 4, kernel_size=[5, 6], stride=[1, 1], padding='valid')");
4625+
}
46084626

46094627
ASSERT_EQ(
46104628
c10::str(Conv3d(4, 4, std::vector<int64_t>{5, 6, 7})),
@@ -4630,6 +4648,15 @@ TEST_F(ModulesTest, PrettyPrintConv) {
46304648
"bias=false, "
46314649
"padding_mode=kCircular)");
46324650
}
4651+
{
4652+
auto options = Conv3dOptions(3, 4, std::vector<int64_t>{5, 6, 7});
4653+
ASSERT_EQ(
4654+
c10::str(Conv3d(options.padding(torch::kSame))),
4655+
"torch::nn::Conv3d(3, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1], padding='same')");
4656+
ASSERT_EQ(
4657+
c10::str(Conv3d(options.padding(torch::kValid))),
4658+
"torch::nn::Conv3d(3, 4, kernel_size=[5, 6, 7], stride=[1, 1, 1], padding='valid')");
4659+
}
46334660
}
46344661

46354662
TEST_F(ModulesTest, PrettyPrintConvTranspose) {

0 commit comments

Comments
 (0)