Skip to content

Commit bed0f51

Browse files
authored
More concise testing (#237)
* More concise testing with tuff * Success reporing consistent with existing tests * Add tuff to cmake build
1 parent 35528d3 commit bed0f51

File tree

4 files changed

+131
-91
lines changed

4 files changed

+131
-91
lines changed

test/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
add_library(tuff tuff.f90)
2+
13
foreach(execid
24
input1d_layer
35
input2d_layer
@@ -27,7 +29,7 @@ foreach(execid
2729
metrics
2830
)
2931
add_executable(test_${execid} test_${execid}.f90)
30-
target_link_libraries(test_${execid} PRIVATE neural-fortran ${LIBS})
32+
target_link_libraries(test_${execid} PRIVATE tuff neural-fortran ${LIBS})
3133

3234
add_test(NAME test_${execid} COMMAND test_${execid})
3335
endforeach()

test/test_dense_layer.f90

Lines changed: 17 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,23 @@
11
program test_dense_layer
2-
use iso_fortran_env, only: stderr => error_unit
3-
use nf, only: dense, layer
4-
use nf_activation, only: relu
2+
use nf, only: dense, layer, relu
3+
use tuff, only: test, test_result
54
implicit none
6-
type(layer) :: layer1, layer2
7-
logical :: ok = .true.
5+
type(layer) :: layer1, layer2, layer3
6+
type(test_result) :: tests
87

98
layer1 = dense(10)
10-
11-
if (.not. layer1 % name == 'dense') then
12-
ok = .false.
13-
write(stderr, '(a)') 'dense layer has its name set correctly.. failed'
14-
end if
15-
16-
if (.not. all(layer1 % layer_shape == [10])) then
17-
ok = .false.
18-
write(stderr, '(a)') 'dense layer is created with requested size.. failed'
19-
end if
20-
21-
if (layer1 % initialized) then
22-
ok = .false.
23-
write(stderr, '(a)') 'dense layer should not be marked as initialized yet.. failed'
24-
end if
25-
26-
if (.not. layer1 % activation == 'sigmoid') then
27-
ok = .false.
28-
write(stderr, '(a)') 'dense layer is defaults to sigmoid activation.. failed'
29-
end if
30-
31-
layer1 = dense(10, activation=relu())
32-
33-
if (.not. layer1 % activation == 'relu') then
34-
ok = .false.
35-
write(stderr, '(a)') 'dense layer is created with the specified activation.. failed'
36-
end if
37-
38-
layer2 = dense(20)
39-
call layer2 % init(layer1)
40-
41-
if (.not. layer2 % initialized) then
42-
ok = .false.
43-
write(stderr, '(a)') 'dense layer should now be marked as initialized.. failed'
44-
end if
45-
46-
if (.not. all(layer2 % input_layer_shape == [10])) then
47-
ok = .false.
48-
write(stderr, '(a)') 'dense layer should have a correct input layer shape.. failed'
49-
end if
50-
51-
if (ok) then
52-
print '(a)', 'test_dense_layer: All tests passed.'
53-
else
54-
write(stderr, '(a)') 'test_dense_layer: One or more tests failed.'
55-
stop 1
56-
end if
9+
layer2 = dense(10, activation=relu())
10+
layer3 = dense(20)
11+
call layer3 % init(layer1)
12+
13+
tests = test("test_dense_layer", [ &
14+
test("layer name is set", layer1 % name == 'dense'), &
15+
test("layer shape is correct", all(layer1 % layer_shape == [10])), &
16+
test("layer is initialized", layer3 % initialized), &
17+
test("layer's default activation is sigmoid", layer1 % activation == 'sigmoid'), &
18+
test("user set activation works", layer2 % activation == 'relu'), &
19+
test("layer initialized after init", layer3 % initialized), &
20+
test("layer input shape is set after init", all(layer3 % input_layer_shape == [10])) &
21+
])
5722

5823
end program test_dense_layer

test/test_dense_network.f90

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,40 @@
11
program test_dense_network
22
use iso_fortran_env, only: stderr => error_unit
3-
use nf, only: dense, input, network
4-
use nf_optimizers, only: sgd
3+
use nf, only: dense, input, network, sgd
4+
use tuff, only: test, test_result
55
implicit none
66
type(network) :: net
7-
logical :: ok = .true.
7+
type(test_result) :: tests
88

99
! Minimal 2-layer network
1010
net = network([ &
1111
input(1), &
1212
dense(1) &
1313
])
1414

15-
if (.not. size(net % layers) == 2) then
16-
write(stderr, '(a)') 'dense network should have 2 layers.. failed'
17-
ok = .false.
18-
end if
15+
tests = test("test_dense_network", [ &
16+
test("network has 2 layers", size(net % layers) == 2), &
17+
test("network predicts 0.5 for input 0", all(net % predict([0.]) == 0.5)), &
18+
test(simple_training), &
19+
test(larger_network_size) &
20+
])
1921

20-
if (.not. all(net % predict([0.]) == 0.5)) then
21-
write(stderr, '(a)') &
22-
'dense network should output exactly 0.5 for input 0.. failed'
23-
ok = .false.
24-
end if
22+
contains
2523

26-
training: block
24+
type(test_result) function simple_training() result(res)
2725
real :: x(1), y(1)
2826
real :: tolerance = 1e-3
2927
integer :: n
30-
integer, parameter :: num_iterations = 1000
28+
integer, parameter :: num_iterations = 1000
29+
type(network) :: net
30+
31+
res % name = 'simple training'
32+
33+
! Minimal 2-layer network
34+
net = network([ &
35+
input(1), &
36+
dense(1) &
37+
])
3138

3239
x = [0.123]
3340
y = [0.765]
@@ -39,32 +46,25 @@ program test_dense_network
3946
if (all(abs(net % predict(x) - y) < tolerance)) exit
4047
end do
4148

42-
if (.not. n <= num_iterations) then
43-
write(stderr, '(a)') &
44-
'dense network should converge in simple training.. failed'
45-
ok = .false.
46-
end if
49+
res % ok = n <= num_iterations
4750

48-
end block training
51+
end function simple_training
4952

50-
! A bit larger multi-layer network
51-
net = network([ &
52-
input(784), &
53-
dense(30), &
54-
dense(20), &
55-
dense(10) &
56-
])
53+
type(test_result) function larger_network_size() result(res)
54+
type(network) :: net
55+
56+
res % name = 'larger network training'
57+
58+
! A bit larger multi-layer network
59+
net = network([ &
60+
input(784), &
61+
dense(30), &
62+
dense(20), &
63+
dense(10) &
64+
])
5765

58-
if (.not. size(net % layers) == 4) then
59-
write(stderr, '(a)') 'dense network should have 4 layers.. failed'
60-
ok = .false.
61-
end if
66+
res % ok = size(net % layers) == 4
6267

63-
if (ok) then
64-
print '(a)', 'test_dense_network: All tests passed.'
65-
else
66-
write(stderr, '(a)') 'test_dense_network: One or more tests failed.'
67-
stop 1
68-
end if
68+
end function larger_network_size
6969

70-
end program test_dense_network
70+
end program test_dense_network

test/tuff.f90

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
module tuff
2+
! Testing Unframework for Fortran (TUFF)
3+
use iso_fortran_env, only: stderr => error_unit, stdout => output_unit
4+
implicit none
5+
6+
private
7+
public :: test, test_result
8+
9+
type :: test_result
10+
character(:), allocatable :: name
11+
logical :: ok = .true.
12+
real :: elapsed = 0.
13+
end type test_result
14+
15+
interface test
16+
module procedure test_logical, test_func, test_array
17+
end interface test
18+
19+
abstract interface
20+
function func() result(res)
21+
import :: test_result
22+
type(test_result) :: res
23+
end function func
24+
end interface
25+
26+
contains
27+
28+
type(test_result) function test_logical(name, cond) result(res)
29+
! Test a single logical expression.
30+
character(*), intent(in) :: name
31+
logical, intent(in) :: cond
32+
res % name = name
33+
res % ok = .true.
34+
res % elapsed = 0.
35+
if (.not. cond) then
36+
write(stderr, '(a)') 'Test ' // trim(name) // ' failed.'
37+
res % ok = .false.
38+
end if
39+
end function test_logical
40+
41+
42+
type(test_result) function test_func(f) result(res)
43+
! Test a user-provided function f that returns a test_result.
44+
! f is responsible for setting the test name and the ok field.
45+
procedure(func) :: f
46+
real :: t1, t2
47+
res % name = ''
48+
call cpu_time(t1)
49+
res = f()
50+
call cpu_time(t2)
51+
res % elapsed = t2 - t1
52+
if (len_trim(res % name) == 0) res % name = 'Anonymous test'
53+
if (.not. res % ok) then
54+
write(stderr, '(a, f6.3)') 'Test failed: ' // trim(res % name)
55+
end if
56+
end function test_func
57+
58+
59+
type(test_result) function test_array(name, tests) result(suite)
60+
! Test a suite of tests, each of which is a test_result.
61+
character(*), intent(in) :: name
62+
type(test_result), intent(in) :: tests(:)
63+
suite % ok = all(tests % ok)
64+
suite % elapsed = sum(tests % elapsed)
65+
if (suite % ok) then
66+
write(stdout, '(a)') trim(name) // ": All tests passed."
67+
else
68+
write(stderr, '(i0,a,i0,a)') count(.not. tests % ok), '/', size(tests), &
69+
" tests failed in suite: " // trim(name)
70+
end if
71+
end function test_array
72+
73+
end module tuff

0 commit comments

Comments
 (0)