Commit e5072d3

Update optimizer flow for layernorm
1 parent ad176ea commit e5072d3

4 files changed (+61 -23 lines)


src/nf/nf_layernorm.f90

Lines changed: 14 additions & 0 deletions
@@ -38,7 +38,9 @@ module nf_layernorm_layer
     procedure :: init
     procedure :: get_num_params
     procedure :: get_params
+    procedure :: get_params_ptr
     procedure :: get_gradients
+    procedure :: get_gradients_ptr
     procedure :: set_params
   end type layernorm_layer

@@ -78,12 +80,24 @@ module function get_params(self) result(params)
     end function get_params


+    module subroutine get_params_ptr(self, g_ptr, b_ptr)
+      class(layernorm_layer), intent(in), target :: self
+      real, pointer, intent(out) :: g_ptr(:), b_ptr(:)
+    end subroutine get_params_ptr
+
+
     module function get_gradients(self) result(gradients)
       class(layernorm_layer), intent(in), target :: self
       real, allocatable :: gradients(:)
     end function get_gradients


+    module subroutine get_gradients_ptr(self, dg_ptr, db_ptr)
+      class(layernorm_layer), intent(in), target :: self
+      real, pointer, intent(out) :: dg_ptr(:), db_ptr(:)
+    end subroutine get_gradients_ptr
+
+
     module subroutine set_params(self, params)
       class(layernorm_layer), intent(in out) :: self
       real, intent(in), target :: params(:)
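
The two new accessors let the optimizer flow obtain pointers to gamma/beta and their gradients and update them in place, instead of flattening the parameters into a copy via get_params and writing them back with set_params. A minimal, self-contained sketch of that pattern follows; the demo_norm type, its values, and the 0.01 learning rate are illustrative stand-ins, not nf's actual code.

module demo_norm_m
  implicit none

  type :: demo_norm
    real, allocatable :: gamma(:), beta(:)
    real, allocatable :: d_gamma(:), d_beta(:)
  contains
    procedure :: get_params_ptr
  end type demo_norm

contains

  subroutine get_params_ptr(self, g_ptr, b_ptr)
    ! Associate the pointers with the parameter arrays; no copy is made, so
    ! later writes through g_ptr and b_ptr land in the layer's own storage.
    class(demo_norm), intent(in), target :: self
    real, pointer, intent(out) :: g_ptr(:), b_ptr(:)
    g_ptr => self % gamma
    b_ptr => self % beta
  end subroutine get_params_ptr

end module demo_norm_m

program demo
  use demo_norm_m
  implicit none
  type(demo_norm), target :: layer
  real, pointer :: g(:), b(:)

  layer = demo_norm(gamma=[1., 1., 1.], beta=[0., 0., 0.], &
                    d_gamma=[0.1, -0.2, 0.3], d_beta=[0.05, 0.05, -0.1])

  call layer % get_params_ptr(g, b)
  ! An SGD-style step applied through the pointers updates the layer in place;
  ! no get_params/set_params copy round trip is needed.
  g = g - 0.01 * layer % d_gamma
  b = b - 0.01 * layer % d_beta

  print *, layer % gamma
  print *, layer % beta
end program demo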

src/nf/nf_layernorm_submodule.f90

Lines changed: 16 additions & 10 deletions
@@ -112,25 +112,31 @@ end function get_num_params
   module function get_params(self) result(params)
     class(layernorm_layer), intent(in), target :: self
     real, allocatable :: params(:)
+    params = [self % gamma, self % beta]
+  end function get_params

-    params = [ &
-      self % gamma, &
-      self % beta &
-    ]

-  end function get_params
+  module subroutine get_params_ptr(self, g_ptr, b_ptr)
+    class(layernorm_layer), intent(in), target :: self
+    real, pointer, intent(out) :: g_ptr(:), b_ptr(:)
+    g_ptr => self % gamma
+    b_ptr => self % beta
+  end subroutine get_params_ptr


   module function get_gradients(self) result(gradients)
     class(layernorm_layer), intent(in), target :: self
     real, allocatable :: gradients(:)
+    gradients = [self % d_gamma, self % d_beta]
+  end function get_gradients

-    gradients = [ &
-      self % d_gamma, &
-      self % d_beta &
-    ]

-  end function get_gradients
+  module subroutine get_gradients_ptr(self, dg_ptr, db_ptr)
+    class(layernorm_layer), intent(in), target :: self
+    real, pointer, intent(out) :: dg_ptr(:), db_ptr(:)
+    dg_ptr => self % d_gamma
+    db_ptr => self % d_beta
+  end subroutine get_gradients_ptr


   module subroutine set_params(self, params)
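
A note on the Fortran semantics the implementation relies on: self is declared intent(in), target because pointer association (=>) requires a target, and intent(in) only restricts what the subroutine itself may modify; later writes made through the returned pointer still reach the layer's arrays. A tiny standalone illustration, with hypothetical names:

program pointer_semantics_demo
  implicit none
  real, target  :: gamma(3) = [1., 2., 3.]
  real, pointer :: p(:)

  p => gamma       ! association, not a copy
  p(2) = 20.       ! a write through p modifies gamma itself
  print *, gamma   ! prints 1.0 20.0 3.0
end program pointer_semantics_demo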

src/nf/nf_network_submodule.f90

Lines changed: 14 additions & 0 deletions
@@ -750,6 +750,20 @@ module subroutine update(self, optimizer, batch_size)
           call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
           this_layer % dw = 0
           this_layer % db = 0
+        type is(linear2d_layer)
+          call this_layer % get_params_ptr(weights, biases)
+          call this_layer % get_gradients_ptr(dw, db)
+          call self % layers(n) % optimizer % minimize(weights, dw / batch_size_)
+          call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
+          this_layer % dw = 0
+          this_layer % db = 0
+        type is(layernorm_layer)
+          call this_layer % get_params_ptr(weights, biases)
+          call this_layer % get_gradients_ptr(dw, db)
+          call self % layers(n) % optimizer % minimize(weights, dw / batch_size_)
+          call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
+          this_layer % d_gamma = 0
+          this_layer % d_beta = 0
       end select
     end do
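
For context on how these branches are reached: update loops over the layers, resolves each one with select type, hands the optimizer direct pointer views of the parameters and gradients, and then zeroes the accumulated gradients for the next batch. A compact standalone sketch of that dispatch, using a hypothetical demo_layernorm type and a plain SGD step in place of nf's layer and optimizer types:

program update_dispatch_demo
  implicit none

  type :: demo_layernorm
    real, allocatable :: gamma(:), beta(:), d_gamma(:), d_beta(:)
  end type demo_layernorm

  class(*), allocatable, target :: this_layer
  real, pointer :: weights(:), biases(:), dw(:), db(:)
  real :: batch_size_ = 4.

  allocate(this_layer, source=demo_layernorm( &
    gamma=[1., 1.], beta=[0., 0.], d_gamma=[0.4, -0.8], d_beta=[0.2, 0.2]))

  select type (this_layer)
    type is (demo_layernorm)
      ! point at the layer's parameters and gradients; no copies are made
      weights => this_layer % gamma;  dw => this_layer % d_gamma
      biases  => this_layer % beta;   db => this_layer % d_beta
      call sgd_minimize(weights, dw / batch_size_)
      call sgd_minimize(biases, db / batch_size_)
      ! reset the accumulated gradients for the next batch
      this_layer % d_gamma = 0
      this_layer % d_beta = 0
      print *, this_layer % gamma   ! updated in place through the pointers
  end select

contains

  subroutine sgd_minimize(param, grad)
    ! plain SGD step: param <- param - lr * grad, acting on the layer's storage
    real, intent(in out) :: param(:)
    real, intent(in) :: grad(:)
    real, parameter :: lr = 0.01
    param = param - lr * grad
  end subroutine sgd_minimize

end program update_dispatch_demo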

test/test_layernorm.f90

Lines changed: 17 additions & 13 deletions
@@ -27,14 +27,14 @@ program test_layernorm_instance
   end if

 contains
-  function allclose(x, y) result(res)
-    real, intent(in) :: x(:)
-    real, intent(in) :: y(:)
-    logical :: res

-    res = all(abs(x - y) <= (1e-06 + 1e-05 * abs(y)))
+  logical function allclose(x, y) result(res)
+    real, intent(in) :: x(:), y(:)
+    !res = all(abs(x - y) <= (1e-06 + 1e-05 * abs(y)))
+    res = all(abs(x - y) <= 1e-05)
   end function allclose

+
   subroutine test_layernorm_forward(layernorm_instance, input, ok)
     type(layernorm_layer), intent(in out) :: layernorm_instance
     real, intent(in out) :: input(:, :)
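
The allclose helper now applies a purely absolute tolerance of 1e-05 instead of the previous mixed absolute/relative bound (kept above as a comment). The two rules accept different cases: the absolute rule is stricter for large magnitudes and slightly looser exactly at zero. A quick standalone comparison with illustrative values:

program allclose_demo
  implicit none
  real :: a = 1000.0,  b = 1000.005   ! small relative error on a large value
  real :: c = 1.0e-5,  d = 0.0        ! small absolute error near zero

  ! old rule: mixed absolute + relative tolerance
  print *, abs(a - b) <= (1e-06 + 1e-05 * abs(b))   ! T (slack grows with abs(b))
  print *, abs(c - d) <= (1e-06 + 1e-05 * abs(d))   ! F (no slack when d is zero)

  ! new rule: purely absolute tolerance of 1e-05
  print *, abs(a - b) <= 1e-05                      ! F
  print *, abs(c - d) <= 1e-05                      ! T
end program allclose_demo
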
@@ -61,6 +61,7 @@ subroutine test_layernorm_forward(layernorm_instance, input, ok)
     end if
   end subroutine test_layernorm_forward

+
   subroutine test_layernorm_backward(layernorm_instance, input, gradient, ok)
     type(layernorm_layer), intent(in out) :: layernorm_instance
     real, intent(in out) :: input(:, :)

@@ -103,6 +104,7 @@ subroutine test_layernorm_backward(layernorm_instance, input, gradient, ok)
     end if
   end subroutine test_layernorm_backward

+
   subroutine test_layernorm_gradients(input, gradient, ok)
     real, intent(in out) :: input(:, :)
     real, intent(in out) :: gradient(:, :)

@@ -152,6 +154,7 @@ subroutine test_layernorm_gradients(input, gradient, ok)
     end if
   end subroutine test_layernorm_gradients

+
   subroutine test_layernorm_integration(ok)
     logical, intent(in out) :: ok

@@ -160,13 +163,13 @@ subroutine test_layernorm_integration(ok)
     real :: y(6) = [0.7, 0.2, 0.1, 0.1, 0.01, 0.9]
     real :: tolerance = 0.1
     integer :: epoch
-    integer :: epochs = 10000
+    integer, parameter :: num_epochs = 100000

-    net = network([&
-      input(2, 3),&
-      linear2d(3),&
-      layernorm(),&
-      flatten()&
+    net = network([ &
+      input(2, 3), &
+      linear2d(3), &
+      layernorm(), &
+      flatten() &
     ])

     ! Kaiming weights to achieve semblance of convergance

@@ -177,17 +180,18 @@ subroutine test_layernorm_integration(ok)
       l % biases = 0.2
     end select

-    do epoch = 1, epochs
+    do epoch = 1, num_epochs
       call net % forward(x)
       call net % backward(y)
       call net % update(optimizer=sgd(learning_rate=0.001))
       if (all(abs(net % predict(x) - y) < tolerance)) exit
     end do

-    if (.not. epoch <= epochs) then
+    if (.not. epoch <= num_epochs) then
       write(stderr, '(a)') &
         'linear2d + layernorm should converge in simple training.. failed'
       ok = .false.
     end if
   end subroutine test_layernorm_integration
+
 end program test_layernorm_instance
