@@ -24,48 +24,22 @@ TEST(Pad, real) {
24
24
for (size_t imgSizeW : {5 , 32 , 96 }) {
25
25
VLOG (3 ) << " numSamples=" << numSamples << " channels=" << channels
26
26
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
27
-
28
- FunctionCompare compare (" Pad" ,
29
- FuncConfig ()
30
- .set (" cstart" , 2 )
31
- .set (" cend" , 3 )
32
- .set (" hstart" , 1 )
33
- .set (" hend" , 2 )
34
- .set (" wstart" , 3 )
35
- .set (" wend" , 2 ));
36
- TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
37
- TensorShape outDims{
38
- numSamples, channels + 5 , imgSizeH + 3 , imgSizeW + 5 };
39
- compare.addInputs (BufferArg (VALUE_TYPE_FLOAT, inDims));
40
- compare.addOutputs (BufferArg (VALUE_TYPE_FLOAT, outDims, ASSIGN_TO));
41
- compare.run ();
42
- }
43
- }
44
- }
45
- }
46
- }
47
-
48
- TEST (PadGrad, real) {
49
- for (size_t numSamples : {5 , 32 }) {
50
- for (size_t channels : {1 , 5 , 32 }) {
51
- for (size_t imgSizeH : {5 , 33 , 100 }) {
52
- for (size_t imgSizeW : {5 , 32 , 96 }) {
53
- VLOG (3 ) << " numSamples=" << numSamples << " channels=" << channels
54
- << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
55
- FunctionCompare compare (" PadGrad" ,
56
- FuncConfig ()
57
- .set (" cstart" , 2 )
58
- .set (" cend" , 3 )
59
- .set (" hstart" , 1 )
60
- .set (" hend" , 2 )
61
- .set (" wstart" , 3 )
62
- .set (" wend" , 2 ));
63
- TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
64
- TensorShape outDims{
65
- numSamples, channels + 5 , imgSizeH + 3 , imgSizeW + 5 };
66
- compare.addInputs (BufferArg (VALUE_TYPE_FLOAT, outDims));
67
- compare.addOutputs (BufferArg (VALUE_TYPE_FLOAT, inDims, ASSIGN_TO));
68
- compare.run ();
27
+ for (bool test_grad : {false , true }) {
28
+ FunctionCompare compare (
29
+ test_grad ? " PadGrad" : " Pad" ,
30
+ FuncConfig ()
31
+ .set <std::vector<uint32_t >>(" channel" , {2 , 3 })
32
+ .set <std::vector<uint32_t >>(" height" , {1 , 2 })
33
+ .set <std::vector<uint32_t >>(" width" , {3 , 2 }));
34
+ TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
35
+ TensorShape outDims{
36
+ numSamples, channels + 5 , imgSizeH + 3 , imgSizeW + 5 };
37
+ compare.addInputs (
38
+ BufferArg (VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
39
+ compare.addOutputs (BufferArg (
40
+ VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO));
41
+ compare.run ();
42
+ }
69
43
}
70
44
}
71
45
}
0 commit comments