55
66contains
77
!> Construct a linear2d_layer producing `out_features` outputs per step.
!!
!! out_features: width of the layer's output (number of units).
!! biases: optional switch for the additive bias vector; when absent
!!         the layer keeps its historical behavior and uses biases.
module function linear2d_layer_cons(out_features, biases) result(res)
  integer, intent(in) :: out_features
  logical, optional, intent(in) :: biases
  type(linear2d_layer) :: res

  res % out_features = out_features
  ! Default on; a present `biases` argument overrides it.
  res % use_biases = .true.
  if (present(biases)) res % use_biases = biases

end function linear2d_layer_cons
1521
@@ -36,8 +42,10 @@ module subroutine init(self, input_shape)
3642 allocate (self % dw(self % in_features, self % out_features))
3743 self % dw = 0
3844
39- allocate (self % db(self % out_features))
40- self % db = 0
45+ if (self % use_biases) then
46+ allocate (self % db(self % out_features))
47+ self % db = 0
48+ end if
4149
4250 end subroutine init
4351
@@ -48,9 +56,11 @@ pure module subroutine forward(self, input)
4856 integer :: i
4957
5058 self % output(:,:) = matmul (input(:,:), self % weights)
51- do concurrent(i = 1 :self % sequence_length)
52- self % output(i,:) = self % output(i,:) + self % biases
53- end do
59+ if (self % use_biases) then
60+ do concurrent(i = 1 :self % sequence_length)
61+ self % output(i,:) = self % output(i,:) + self % biases
62+ end do
63+ end if
5464
5565 end subroutine forward
5666
@@ -64,7 +74,9 @@ pure module subroutine backward(self, input, gradient)
6474 integer :: i
6575
6676 self % dw = self % dw + matmul (transpose (input(:,:)), gradient(:,:))
67- self % db = self % db + sum (gradient(:,:), 1 )
77+ if (self % use_biases) then
78+ self % db = self % db + sum (gradient(:,:), 1 )
79+ end if
6880 self % gradient(:,:) = matmul (gradient(:,:), transpose (self % weights))
6981 end subroutine backward
7082
@@ -74,7 +86,10 @@ pure module function get_num_params(self) result(num_params)
7486 integer :: num_params
7587
7688 ! Number of weights plus number of biases
77- num_params = self % in_features * self % out_features + self % out_features
89+ num_params = self % in_features * self % out_features
90+ if (self % use_biases) then
91+ num_params = num_params + self % out_features
92+ end if
7893
7994 end function get_num_params
8095
@@ -87,10 +102,14 @@ module function get_params(self) result(params)
87102
88103 w_(1 : product (shape (self % weights))) => self % weights
89104
90- params = [ &
91- w_, &
92- self % biases &
93- ]
105+ if (self % use_biases) then
106+ params = [ &
107+ w_, &
108+ self % biases &
109+ ]
110+ else
111+ params = w_
112+ end if
94113
95114 end function get_params
96115
@@ -103,10 +122,14 @@ module function get_gradients(self) result(gradients)
103122
104123 dw_(1 : product (shape (self % dw))) => self % dw
105124
106- gradients = [ &
107- dw_, &
108- self % db &
109- ]
125+ if (self % use_biases) then
126+ gradients = [ &
127+ dw_, &
128+ self % db &
129+ ]
130+ else
131+ gradients = dw_
132+ end if
110133
111134 end function get_gradients
112135
@@ -127,10 +150,12 @@ module subroutine set_params(self, params)
127150 p_(1 :self % in_features, 1 :self % out_features) => params(1 : n)
128151 self % weights = p_
129152
130- ! reshape the biases
131- self % biases = params(n + 1 : n + self % out_features)
153+ if (self % use_biases) then
154+ ! reshape the biases
155+ self % biases = params(n + 1 : n + self % out_features)
156+ end if
132157 end associate
133158
134159 end subroutine set_params
135160
136- end submodule nf_linear2d_layer_submodule
161+ end submodule nf_linear2d_layer_submodule