Skip to content

Commit e1c226f

Browse files
committed
linear2d_layer_no_bias: add optional argument to disable biases
1 parent e68e6c2 commit e1c226f

File tree

4 files changed

+54
-24
lines changed

4 files changed

+54
-24
lines changed

src/nf/nf_layer_constructors.f90

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,14 @@ module function reshape(output_shape) result(res)
213213
!! Resulting layer instance
214214
end function reshape
215215

216-
module function linear2d(out_features) result(res)
216+
module function linear2d(out_features, biases) result(res)
217217
!! Rank-2 (sequence_length, out_features) linear layer constructor.
218218
!! sequence_length is determined at layer initialization, based on the
219219
!! output shape of the previous layer.
220220
integer, intent(in) :: out_features
221221
!! Number of output features
222+
logical, optional :: biases
223+
!! Whether to use biases; biases are enabled (.true.) when this argument is not present
222224
type(layer) :: res
223225
!! Resulting layer instance
224226
end function linear2d

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,13 @@ module function reshape(output_shape) result(res)
163163
end function reshape
164164

165165

166-
module function linear2d(out_features) result(res)
166+
module function linear2d(out_features, biases) result(res)
167167
integer, intent(in) :: out_features
168+
logical, optional :: biases
168169
type(layer) :: res
169170

170171
res % name = 'linear2d'
171-
allocate(res % p, source=linear2d_layer(out_features))
172+
allocate(res % p, source=linear2d_layer(out_features, biases))
172173

173174
end function linear2d
174175

src/nf/nf_linear2d_layer.f90

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ module nf_linear2d_layer
99
public :: linear2d_layer
1010

1111
type, extends(base_layer) :: linear2d_layer
12-
integer :: sequence_length, in_features, out_features, batch_size
12+
integer :: sequence_length, in_features, out_features
13+
logical :: use_biases
1314

1415
real, allocatable :: weights(:,:)
1516
real, allocatable :: biases(:)
@@ -31,8 +32,9 @@ module nf_linear2d_layer
3132
end type linear2d_layer
3233

3334
interface linear2d_layer
34-
module function linear2d_layer_cons(out_features) result(res)
35+
module function linear2d_layer_cons(out_features, biases) result(res)
3536
integer, intent(in) :: out_features
37+
logical, optional, intent(in) :: biases
3638
type(linear2d_layer) :: res
3739
end function linear2d_layer_cons
3840
end interface linear2d_layer

src/nf/nf_linear2d_layer_submodule.f90

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,17 @@
55

66
contains
77

8-
module function linear2d_layer_cons(out_features) result(res)
8+
module function linear2d_layer_cons(out_features, biases) result(res)
99
integer, intent(in) :: out_features
10+
logical, optional, intent(in) :: biases
1011
type(linear2d_layer) :: res
1112

1213
res % out_features = out_features
14+
if (present(biases)) then
15+
res % use_biases = biases
16+
else
17+
res % use_biases = .true.
18+
end if
1319

1420
end function linear2d_layer_cons
1521

@@ -36,8 +42,10 @@ module subroutine init(self, input_shape)
3642
allocate(self % dw(self % in_features, self % out_features))
3743
self % dw = 0
3844

39-
allocate(self % db(self % out_features))
40-
self % db = 0
45+
if (self % use_biases) then
46+
allocate(self % db(self % out_features))
47+
self % db = 0
48+
end if
4149

4250
end subroutine init
4351

@@ -48,9 +56,11 @@ pure module subroutine forward(self, input)
4856
integer :: i
4957

5058
self % output(:,:) = matmul(input(:,:), self % weights)
51-
do concurrent(i = 1:self % sequence_length)
52-
self % output(i,:) = self % output(i,:) + self % biases
53-
end do
59+
if (self % use_biases) then
60+
do concurrent(i = 1:self % sequence_length)
61+
self % output(i,:) = self % output(i,:) + self % biases
62+
end do
63+
end if
5464

5565
end subroutine forward
5666

@@ -64,7 +74,9 @@ pure module subroutine backward(self, input, gradient)
6474
integer :: i
6575

6676
self % dw = self % dw + matmul(transpose(input(:,:)), gradient(:,:))
67-
self % db = self % db + sum(gradient(:,:), 1)
77+
if (self % use_biases) then
78+
self % db = self % db + sum(gradient(:,:), 1)
79+
end if
6880
self % gradient(:,:) = matmul(gradient(:,:), transpose(self % weights))
6981
end subroutine backward
7082

@@ -74,7 +86,10 @@ pure module function get_num_params(self) result(num_params)
7486
integer :: num_params
7587

7688
! Number of weights, plus the number of biases when biases are enabled
77-
num_params = self % in_features * self % out_features + self % out_features
89+
num_params = self % in_features * self % out_features
90+
if (self % use_biases) then
91+
num_params = num_params + self % out_features
92+
end if
7893

7994
end function get_num_params
8095

@@ -87,10 +102,14 @@ module function get_params(self) result(params)
87102

88103
w_(1: product(shape(self % weights))) => self % weights
89104

90-
params = [ &
91-
w_, &
92-
self % biases &
93-
]
105+
if (self % use_biases) then
106+
params = [ &
107+
w_, &
108+
self % biases &
109+
]
110+
else
111+
params = w_
112+
end if
94113

95114
end function get_params
96115

@@ -103,10 +122,14 @@ module function get_gradients(self) result(gradients)
103122

104123
dw_(1: product(shape(self % dw))) => self % dw
105124

106-
gradients = [ &
107-
dw_, &
108-
self % db &
109-
]
125+
if (self % use_biases) then
126+
gradients = [ &
127+
dw_, &
128+
self % db &
129+
]
130+
else
131+
gradients = dw_
132+
end if
110133

111134
end function get_gradients
112135

@@ -127,10 +150,12 @@ module subroutine set_params(self, params)
127150
p_(1:self % in_features, 1:self % out_features) => params(1 : n)
128151
self % weights = p_
129152

130-
! reshape the biases
131-
self % biases = params(n + 1 : n + self % out_features)
153+
if (self % use_biases) then
154+
! copy the biases from the slice of params that follows the weights
155+
self % biases = params(n + 1 : n + self % out_features)
156+
end if
132157
end associate
133158

134159
end subroutine set_params
135160

136-
end submodule nf_linear2d_layer_submodule
161+
end submodule nf_linear2d_layer_submodule

0 commit comments

Comments
 (0)