Skip to content

Commit 703f802

Browse files
committed
Integration of backward pass for dropout
1 parent 3b5cc27 commit 703f802

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

src/nf/nf_layer_submodule.f90

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@ pure module subroutine backward_1d(self, previous, gradient)
2525

2626
type is(dense_layer)
2727

28-
! Upstream layers permitted: input1d, dense, flatten
28+
! Upstream layers permitted: input1d, dense, dropout, flatten
2929
select type(prev_layer => previous % p)
3030
type is(input1d_layer)
3131
call this_layer % backward(prev_layer % output, gradient)
3232
type is(dense_layer)
3333
call this_layer % backward(prev_layer % output, gradient)
34+
type is(dropout_layer)
35+
call this_layer % backward(prev_layer % output, gradient)
3436
type is(flatten_layer)
3537
call this_layer % backward(prev_layer % output, gradient)
3638
end select
@@ -116,12 +118,14 @@ module subroutine forward(self, input)
116118

117119
type is(dense_layer)
118120

119-
! Upstream layers permitted: input1d, dense, flatten
121+
! Upstream layers permitted: input1d, dense, dropout, flatten
120122
select type(prev_layer => input % p)
121123
type is(input1d_layer)
122124
call this_layer % forward(prev_layer % output)
123125
type is(dense_layer)
124126
call this_layer % forward(prev_layer % output)
127+
type is(dropout_layer)
128+
call this_layer % forward(prev_layer % output)
125129
type is(flatten_layer)
126130
call this_layer % forward(prev_layer % output)
127131
end select
@@ -299,6 +303,8 @@ elemental module function get_num_params(self) result(num_params)
299303
num_params = 0
300304
type is (dense_layer)
301305
num_params = this_layer % get_num_params()
306+
type is (dropout_layer)
307+
num_params = size(this_layer % mask)
302308
type is (conv2d_layer)
303309
num_params = this_layer % get_num_params()
304310
type is (maxpool2d_layer)
@@ -324,6 +330,8 @@ module function get_params(self) result(params)
324330
! No parameters to get.
325331
type is (dense_layer)
326332
params = this_layer % get_params()
333+
type is (dropout_layer)
334+
! No parameters to get.
327335
type is (conv2d_layer)
328336
params = this_layer % get_params()
329337
type is (maxpool2d_layer)
@@ -349,6 +357,8 @@ module function get_gradients(self) result(gradients)
349357
! No gradients to get.
350358
type is (dense_layer)
351359
gradients = this_layer % get_gradients()
360+
type is (dropout_layer)
361+
! No gradients to get.
352362
type is (conv2d_layer)
353363
gradients = this_layer % get_gradients()
354364
type is (maxpool2d_layer)
@@ -396,6 +406,11 @@ module subroutine set_params(self, params)
396406
type is (dense_layer)
397407
call this_layer % set_params(params)
398408

409+
type is (dropout_layer)
410+
! No parameters to set.
411+
write(stderr, '(a)') 'Warning: calling set_params() ' &
412+
// 'on a zero-parameter layer; nothing to do.'
413+
399414
type is (conv2d_layer)
400415
call this_layer % set_params(params)
401416

src/nf/nf_network_submodule.f90

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ module subroutine backward(self, output, loss)
135135
select type(next_layer => self % layers(n + 1) % p)
136136
type is(dense_layer)
137137
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
138+
type is(dropout_layer)
139+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
138140
type is(conv2d_layer)
139141
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
140142
type is(flatten_layer)

0 commit comments

Comments
 (0)