Skip to content

Commit b9f418f

Browse files
committed
Addition of a test for stride in locally_connected2d_layer
1 parent 663860d commit b9f418f

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

test/test_locally_connected2d_layer.f90

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ program test_locally_connected2d_layer
5858
select type(this_layer => input_layer % p); type is(input2d_layer)
5959
call this_layer % set(sample_input)
6060
end select
61+
deallocate(sample_input)
6162

6263
call locally_connected_1d_layer % forward(input_layer)
6364
call locally_connected_1d_layer % get_output(output)
@@ -67,11 +68,33 @@ program test_locally_connected2d_layer
6768
write(stderr, '(a)') 'locally_connected2d layer with zero input and sigmoid function must forward to all 0.5.. failed'
6869
end if
6970

71+
! Minimal locally_connected_1d layer: 1 channel, 3x3 pixel image, stride = 3;
72+
allocate(sample_input(1, 17))
73+
sample_input = 0
74+
75+
input_layer = input(1, 17)
76+
locally_connected_1d_layer = locally_connected(filters, kernel_size, stride = 3)
77+
call locally_connected_1d_layer % init(input_layer)
78+
79+
select type(this_layer => input_layer % p); type is(input2d_layer)
80+
call this_layer % set(sample_input)
81+
end select
82+
deallocate(sample_input)
83+
84+
call locally_connected_1d_layer % forward(input_layer)
85+
call locally_connected_1d_layer % get_output(output)
86+
87+
if (.not. all(abs(output) < tolerance)) then
88+
ok = .false.
89+
write(stderr, '(a)') 'locally_connected2d layer with zero input and sigmoid function must forward to all 0.5.. failed'
90+
end if
91+
92+
!Final
7093
if (ok) then
7194
print '(a)', 'test_locally_connected2d_layer: All tests passed.'
7295
else
7396
write(stderr, '(a)') 'test_locally_connected2d_layer: One or more tests failed.'
7497
stop 1
7598
end if
76-
99+
77100
end program test_locally_connected2d_layer

0 commit comments

Comments
 (0)