Skip to content

Commit 7d7eded

Browse files
committed
Add Input2d layer by redesigning the input parameters to input layer constructors
1 parent d516437 commit 7d7eded

11 files changed

+129
-25
lines changed

src/nf/nf_input2d_layer.f90

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
module nf_input2d_layer
2+
3+
!! This module provides the `input2d_layer` type.
4+
5+
use nf_base_layer, only: base_layer
6+
implicit none
7+
8+
private
9+
public :: input2d_layer
10+
11+
type, extends(base_layer) :: input2d_layer
12+
real, allocatable :: output(:,:)
13+
contains
14+
procedure :: init
15+
procedure :: set
16+
end type input2d_layer
17+
18+
interface input2d_layer
19+
pure module function input2d_layer_cons(output_shape) result(res)
20+
!! Create a new instance of the 2-d input layer.
21+
!! Only used internally by the `layer % init` method.
22+
integer, intent(in) :: output_shape(2)
23+
!! Shape of the input layer
24+
type(input2d_layer) :: res
25+
!! 2-d input layer instance
26+
end function input2d_layer_cons
27+
end interface input2d_layer
28+
29+
interface
30+
31+
module subroutine init(self, input_shape)
32+
!! Only here to satisfy the language rules
33+
!! about deferred methods of abstract types.
34+
!! This method does nothing for this type and should not be called.
35+
class(input2d_layer), intent(in out) :: self
36+
integer, intent(in) :: input_shape(:)
37+
end subroutine init
38+
39+
pure module subroutine set(self, values)
40+
class(input2d_layer), intent(in out) :: self
41+
!! Layer instance
42+
real, intent(in) :: values(:,:)
43+
!! Values to set
44+
end subroutine set
45+
46+
end interface
47+
48+
end module nf_input2d_layer
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
submodule(nf_input2d_layer) nf_input2d_layer_submodule
2+
implicit none
3+
contains
4+
5+
pure module function input2d_layer_cons(output_shape) result(res)
6+
integer, intent(in) :: output_shape(2)
7+
type(input2d_layer) :: res
8+
allocate(res % output(output_shape(1), output_shape(2)))
9+
res % output = 0
10+
end function input2d_layer_cons
11+
12+
module subroutine init(self, input_shape)
13+
class(input2d_layer), intent(in out) :: self
14+
integer, intent(in) :: input_shape(:)
15+
end subroutine init
16+
17+
pure module subroutine set(self, values)
18+
class(input2d_layer), intent(in out) :: self
19+
real, intent(in) :: values(:,:)
20+
self % output = values
21+
end subroutine set
22+
23+
end submodule nf_input2d_layer_submodule

src/nf/nf_layer_constructors.f90

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ module function input1d(layer_size) result(res)
3535
!! Resulting layer instance
3636
end function input1d
3737

38-
module function input3d(layer_shape) result(res)
39-
!! 3-d input layer constructor.
38+
module function input2d(dim1, dim2) result(res)
39+
!! 2-d input layer constructor.
4040
!!
41-
!! This layer is for inputting 3-d data to the network.
41+
!! This layer is for inputting 2-d data to the network.
4242
!! Currently, this layer must be followed by a conv2d layer.
4343
!! An input layer must be the first layer in the network.
4444
!!
@@ -50,10 +50,29 @@ module function input3d(layer_shape) result(res)
5050
!! ```
5151
!! use nf, only :: input, layer
5252
!! type(layer) :: input_layer
53-
!! input_layer = input([28, 28, 1])
53+
!! input_layer = input(28, 28)
54+
!! ```
55+
integer, intent(in) :: dim1, dim2
56+
!! First and second dimension sizes
57+
type(layer) :: res
58+
!! Resulting layer instance
59+
end function input2d
60+
61+
module function input3d(dim1, dim2, dim3) result(res)
62+
!! 3-d input layer constructor.
63+
!!
64+
!! This is a specific function that is available
65+
!! under a generic name `input`.
66+
!!
67+
!! Example:
68+
!!
69+
!! ```
70+
!! use nf, only :: input, layer
71+
!! type(layer) :: input_layer
72+
!! input_layer = input(28, 28, 1)
5473
!! ```
55-
integer, intent(in) :: layer_shape(3)
56-
!! Shape of the input layer
74+
integer, intent(in) :: dim1, dim2, dim3
75+
!! First, second and third dimension sizes
5776
type(layer) :: res
5877
!! Resulting layer instance
5978
end function input3d

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
use nf_dense_layer, only: dense_layer
66
use nf_flatten_layer, only: flatten_layer
77
use nf_input1d_layer, only: input1d_layer
8+
use nf_input2d_layer, only: input2d_layer
89
use nf_input3d_layer, only: input3d_layer
910
use nf_maxpool2d_layer, only: maxpool2d_layer
1011
use nf_reshape_layer, only: reshape3d_layer
@@ -81,16 +82,28 @@ module function input1d(layer_size) result(res)
8182
end function input1d
8283

8384

84-
module function input3d(layer_shape) result(res)
85-
integer, intent(in) :: layer_shape(3)
85+
module function input2d(dim1, dim2) result(res)
86+
integer, intent(in) :: dim1, dim2
8687
type(layer) :: res
8788
res % name = 'input'
88-
res % layer_shape = layer_shape
89+
res % layer_shape = [dim1, dim2]
8990
res % input_layer_shape = [integer ::]
90-
allocate(res % p, source=input3d_layer(layer_shape))
91+
allocate(res % p, source=input2d_layer([dim1, dim2]))
92+
res % initialized = .true.
93+
end function input2d
94+
95+
96+
module function input3d(dim1, dim2, dim3) result(res)
97+
integer, intent(in) :: dim1, dim2, dim3
98+
type(layer) :: res
99+
res % name = 'input'
100+
res % layer_shape = [dim1, dim2, dim3]
101+
res % input_layer_shape = [integer ::]
102+
allocate(res % p, source=input3d_layer([dim1, dim2, dim3]))
91103
res % initialized = .true.
92104
end function input3d
93105

106+
94107
module function maxpool2d(pool_size, stride) result(res)
95108
integer, intent(in) :: pool_size
96109
integer, intent(in), optional :: stride
@@ -119,6 +132,7 @@ module function maxpool2d(pool_size, stride) result(res)
119132

120133
end function maxpool2d
121134

135+
122136
module function reshape(output_shape) result(res)
123137
integer, intent(in) :: output_shape(:)
124138
type(layer) :: res

test/test_conv2d_layer.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ program test_conv2d_layer
2929
write(stderr, '(a)') 'conv2d layer defaults to relu activation.. failed'
3030
end if
3131

32-
input_layer = input([3, 32, 32])
32+
input_layer = input(3, 32, 32)
3333
call conv_layer % init(input_layer)
3434

3535
if (.not. conv_layer % initialized) then
@@ -51,7 +51,7 @@ program test_conv2d_layer
5151
allocate(sample_input(1, 3, 3))
5252
sample_input = 0
5353

54-
input_layer = input([1, 3, 3])
54+
input_layer = input(1, 3, 3)
5555
conv_layer = conv2d(filters, kernel_size)
5656
call conv_layer % init(input_layer)
5757

test/test_conv2d_network.f90

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ program test_conv2d_network
1111

1212
! 3-layer convolutional network
1313
net = network([ &
14-
input([3, 32, 32]), &
14+
input(3, 32, 32), &
1515
conv2d(filters=16, kernel_size=3), &
1616
conv2d(filters=32, kernel_size=3) &
1717
])
@@ -48,7 +48,7 @@ program test_conv2d_network
4848
call random_number(sample_input)
4949

5050
cnn = network([ &
51-
input(shape(sample_input)), &
51+
input(1, 5, 5), &
5252
conv2d(filters=1, kernel_size=3), &
5353
conv2d(filters=1, kernel_size=3), &
5454
dense(1) &
@@ -84,7 +84,7 @@ program test_conv2d_network
8484
y = [0.1234567]
8585

8686
cnn = network([ &
87-
input(shape(x)), &
87+
input(1, 8, 8), &
8888
conv2d(filters=1, kernel_size=3), &
8989
maxpool2d(pool_size=2), &
9090
conv2d(filters=1, kernel_size=3), &
@@ -119,7 +119,7 @@ program test_conv2d_network
119119
y = [0.12345, 0.23456, 0.34567, 0.45678, 0.56789, 0.67890, 0.78901, 0.89012, 0.90123]
120120

121121
cnn = network([ &
122-
input(shape(x)), &
122+
input(1, 12, 12), &
123123
conv2d(filters=1, kernel_size=3), & ! 1x12x12 input, 1x10x10 output
124124
maxpool2d(pool_size=2), & ! 1x10x10 input, 1x5x5 output
125125
conv2d(filters=1, kernel_size=3), & ! 1x5x5 input, 1x3x3 output

test/test_flatten_layer.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ program test_flatten_layer
2525
write(stderr, '(a)') 'flatten layer is not initialized yet.. failed'
2626
end if
2727

28-
input_layer = input([1, 2, 2])
28+
input_layer = input(1, 2, 2)
2929
call test_layer % init(input_layer)
3030

3131
if (.not. test_layer % initialized) then
@@ -68,7 +68,7 @@ program test_flatten_layer
6868
end if
6969

7070
net = network([ &
71-
input([1, 28, 28]), &
71+
input(1, 28, 28), &
7272
flatten(), &
7373
dense(10) &
7474
])

test/test_get_set_network_params.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ program test_get_set_network_params
99

1010
! First test get_num_params()
1111
net = network([ &
12-
input([3, 5, 5]), & ! 5 x 5 image with 3 channels
12+
input(3, 5, 5), & ! 5 x 5 image with 3 channels
1313
conv2d(filters=2, kernel_size=3), & ! kernel shape [2, 3, 3, 3], output shape [2, 3, 3], 56 parameters total
1414
flatten(), &
1515
dense(4) & ! weights shape [72], biases shape [4], 76 parameters total
@@ -45,7 +45,7 @@ program test_get_set_network_params
4545

4646
! Finally, test set_params() and get_params() for a conv2d layer
4747
net = network([ &
48-
input([1, 3, 3]), &
48+
input(1, 3, 3), &
4949
conv2d(filters=1, kernel_size=3) &
5050
])
5151

test/test_input3d_layer.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ program test_input3d_layer
1010
real, allocatable :: output(:,:,:)
1111
logical :: ok = .true.
1212

13-
test_layer = input([3, 32, 32])
13+
test_layer = input(3, 32, 32)
1414

1515
if (.not. test_layer % name == 'input') then
1616
ok = .false.

test/test_insert_flatten.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ program test_insert_flatten
99
logical :: ok = .true.
1010

1111
net = network([ &
12-
input([3, 32, 32]), &
12+
input(3, 32, 32), &
1313
dense(10) &
1414
])
1515

@@ -19,7 +19,7 @@ program test_insert_flatten
1919
end if
2020

2121
net = network([ &
22-
input([3, 32, 32]), &
22+
input(3, 32, 32), &
2323
conv2d(filters=1, kernel_size=3), &
2424
dense(10) &
2525
])
@@ -32,7 +32,7 @@ program test_insert_flatten
3232
end if
3333

3434
net = network([ &
35-
input([3, 32, 32]), &
35+
input(3, 32, 32), &
3636
conv2d(filters=1, kernel_size=3), &
3737
maxpool2d(pool_size=2, stride=2), &
3838
dense(10) &

0 commit comments

Comments
 (0)