|
5 | 5 | use nf_dense_layer, only: dense_layer |
6 | 6 | use nf_flatten_layer, only: flatten_layer |
7 | 7 | use nf_input1d_layer, only: input1d_layer |
| 8 | + use nf_input2d_layer, only: input2d_layer |
8 | 9 | use nf_input3d_layer, only: input3d_layer |
9 | 10 | use nf_maxpool2d_layer, only: maxpool2d_layer |
10 | 11 | use nf_reshape_layer, only: reshape3d_layer |
@@ -81,16 +82,28 @@ module function input1d(layer_size) result(res) |
81 | 82 | end function input1d |
82 | 83 |
|
83 | 84 |
|
84 | | - module function input3d(layer_shape) result(res) |
85 | | - integer, intent(in) :: layer_shape(3) |
| 85 | + module function input2d(dim1, dim2) result(res) |
| 86 | + integer, intent(in) :: dim1, dim2 |
86 | 87 | type(layer) :: res |
87 | 88 | res % name = 'input' |
88 | | - res % layer_shape = layer_shape |
| 89 | + res % layer_shape = [dim1, dim2] |
89 | 90 | res % input_layer_shape = [integer ::] |
90 | | - allocate(res % p, source=input3d_layer(layer_shape)) |
| 91 | + allocate(res % p, source=input2d_layer([dim1, dim2])) |
| 92 | + res % initialized = .true. |
| 93 | + end function input2d |
| 94 | + |
| 95 | + |
| 96 | + module function input3d(dim1, dim2, dim3) result(res) |
| 97 | + integer, intent(in) :: dim1, dim2, dim3 |
| 98 | + type(layer) :: res |
| 99 | + res % name = 'input' |
| 100 | + res % layer_shape = [dim1, dim2, dim3] |
| 101 | + res % input_layer_shape = [integer ::] |
| 102 | + allocate(res % p, source=input3d_layer([dim1, dim2, dim3])) |
91 | 103 | res % initialized = .true. |
92 | 104 | end function input3d |
93 | 105 |
|
| 106 | + |
94 | 107 | module function maxpool2d(pool_size, stride) result(res) |
95 | 108 | integer, intent(in) :: pool_size |
96 | 109 | integer, intent(in), optional :: stride |
@@ -119,6 +132,7 @@ module function maxpool2d(pool_size, stride) result(res) |
119 | 132 |
|
120 | 133 | end function maxpool2d |
121 | 134 |
|
| 135 | + |
122 | 136 | module function reshape(output_shape) result(res) |
123 | 137 | integer, intent(in) :: output_shape(:) |
124 | 138 | type(layer) :: res |
|
0 commit comments