Skip to content

Commit b964d3a

Browse files
committed
Added a few tests for torch
1 parent 5d9808b commit b964d3a

File tree

4 files changed

+55
-15
lines changed

4 files changed

+55
-15
lines changed

modules/dnn/samples/torch_enet.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const String keys =
2323
"{c_names c || path to file with classnames for channels (optional, categories.txt) }"
2424
"{result r || path to save output blob (optional, binary format, NCHW order) }"
2525
"{show s || whether to show all output channels or not}"
26+
"{o_blob || output blob's name. If empty, last blob's name in net is used}"
2627
;
2728

2829
std::vector<String> readClassNames(const char *filename);
@@ -112,7 +113,13 @@ int main(int argc, char **argv)
112113

113114
//! [Gather output]
114115

115-
dnn::Blob prob = net.getBlob(net.getLayerNames().back()); //gather output of "prob" layer
116+
String oBlob = net.getLayerNames().back();
117+
if (!parser.get<String>("o_blob").empty())
118+
{
119+
oBlob = parser.get<String>("o_blob");
120+
}
121+
122+
dnn::Blob prob = net.getBlob(oBlob); //gather output of "prob" layer
116123

117124
Mat& result = prob.matRef();
118125

modules/dnn/src/layers/max_unpooling_layer.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ void MaxUnpoolLayerImpl::allocate(const std::vector<Blob*> &inputs, std::vector<
2828
outShape[2] = outSize.height;
2929
outShape[3] = outSize.width;
3030

31+
CV_Assert(inputs[0]->total() == inputs[1]->total());
32+
3133
outputs.resize(1);
3234
outputs[0].create(outShape);
3335
}

modules/dnn/test/test_torch_importer.cpp

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ TEST(Torch_Importer, simple_read)
7272
importer->populateNet(net);
7373
}
7474

75-
static void runTorchNet(String prefix, String outLayerName, bool isBinary)
75+
static void runTorchNet(String prefix, String outLayerName = "",
76+
bool check2ndBlob = false, bool isBinary = false)
7677
{
7778
String suffix = (isBinary) ? ".dat" : ".txt";
7879

@@ -92,52 +93,69 @@ static void runTorchNet(String prefix, String outLayerName, bool isBinary)
9293
Blob out = net.getBlob(outLayerName);
9394

9495
normAssert(outRef, out);
96+
97+
if (check2ndBlob)
98+
{
99+
Blob out2 = net.getBlob(outLayerName + ".1");
100+
Blob ref2 = readTorchBlob(_tf(prefix + "_output_2" + suffix), isBinary);
101+
normAssert(out2, ref2);
102+
}
95103
}
96104

97105
TEST(Torch_Importer, run_convolution)
98106
{
99-
runTorchNet("net_conv", "l1_Convolution", false);
107+
runTorchNet("net_conv");
100108
}
101109

102110
TEST(Torch_Importer, run_pool_max)
103111
{
104-
runTorchNet("net_pool_max", "l1_Pooling", false);
112+
runTorchNet("net_pool_max", "", true);
105113
}
106114

107115
TEST(Torch_Importer, run_pool_ave)
108116
{
109-
runTorchNet("net_pool_ave", "l1_Pooling", false);
117+
runTorchNet("net_pool_ave");
110118
}
111119

112120
TEST(Torch_Importer, run_reshape)
113121
{
114-
runTorchNet("net_reshape", "l1_Reshape", false);
115-
runTorchNet("net_reshape_batch", "l1_Reshape", false);
122+
runTorchNet("net_reshape");
123+
runTorchNet("net_reshape_batch");
116124
}
117125

118126
TEST(Torch_Importer, run_linear)
119127
{
120-
runTorchNet("net_linear_2d", "l1_InnerProduct", false);
128+
runTorchNet("net_linear_2d");
121129
}
122130

123131
TEST(Torch_Importer, run_paralel)
124132
{
125-
runTorchNet("net_parallel", "l2_torchMerge", false);
133+
runTorchNet("net_parallel", "l2_torchMerge");
126134
}
127135

128136
TEST(Torch_Importer, run_concat)
129137
{
130-
runTorchNet("net_concat", "l2_torchMerge", false);
138+
runTorchNet("net_concat", "l2_torchMerge");
131139
}
132140

133141
TEST(Torch_Importer, run_deconv)
134142
{
135-
runTorchNet("net_deconv", "", false);
143+
runTorchNet("net_deconv");
136144
}
137145

138146
TEST(Torch_Importer, run_batch_norm)
139147
{
140-
runTorchNet("net_batch_norm", "", false);
148+
runTorchNet("net_batch_norm");
149+
}
150+
151+
TEST(Torch_Importer, net_prelu)
152+
{
153+
runTorchNet("net_prelu");
154+
}
155+
156+
TEST(Torch_Importer, net_cadd_table)
157+
{
158+
runTorchNet("net_cadd_table");
141159
}
142160

143161
#if defined(ENABLE_TORCH_ENET_TESTS)

modules/dnn/testdata/dnn/torch/torch_gen_test_data.lua

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ function save(net, input, label)
2727
torch.save(label .. '_input.txt', input, 'ascii')
2828
--torch.save(label .. '_output.dat', output)
2929
torch.save(label .. '_output.txt', output, 'ascii')
30+
31+
return net
3032
end
3133

3234
local net_simple = nn.Sequential()
@@ -38,7 +40,8 @@ save(net_simple, torch.Tensor(2, 3, 25, 35), 'net_simple')
3840

3941
local net_pool_max = nn.Sequential()
4042
net_pool_max:add(nn.SpatialMaxPooling(4,5, 3,2, 1,2):ceil()) --TODO: add ceil and floor modes
41-
save(net_pool_max, torch.rand(2, 3, 50, 30), 'net_pool_max')
43+
local net = save(net_pool_max, torch.rand(2, 3, 50, 30), 'net_pool_max')
44+
torch.save('net_pool_max_output_2.txt', net.modules[1].indices - 1, 'ascii')
4245

4346
local net_pool_ave = nn.Sequential()
4447
net_pool_ave:add(nn.SpatialAveragePooling(4,5, 2,1, 1,2))
@@ -74,5 +77,15 @@ net_deconv:add(nn.SpatialFullConvolution(3, 9, 4, 5, 1, 2, 0, 1, 0, 1))
7477
save(net_deconv, torch.rand(2, 3, 4, 3) - 0.5, 'net_deconv')
7578

7679
local net_batch_norm = nn.Sequential()
77-
net_batch_norm:add(nn.SpatialBatchNormalization(3))
78-
save(net_batch_norm, torch.rand(1, 3, 4, 3) - 0.5, 'net_batch_norm')
80+
net_batch_norm:add(nn.SpatialBatchNormalization(4, 1e-3))
81+
save(net_batch_norm, torch.rand(1, 4, 5, 6) - 0.5, 'net_batch_norm')
82+
83+
local net_prelu = nn.Sequential()
84+
net_prelu:add(nn.PReLU(5))
85+
save(net_prelu, torch.rand(1, 5, 40, 50) - 0.5, 'net_prelu')
86+
87+
local net_cadd_table = nn.Sequential()
88+
local sum = nn.ConcatTable()
89+
sum:add(nn.Identity()):add(nn.Identity())
90+
net_cadd_table:add(sum):add(nn.CAddTable())
91+
save(net_cadd_table, torch.rand(1, 5, 40, 50) - 0.5, 'net_cadd_table')

0 commit comments

Comments
 (0)