1616 use nf_reshape_layer, only: reshape3d_layer
1717 use nf_linear2d_layer, only: linear2d_layer
1818 use nf_self_attention_layer, only: self_attention_layer
19+ use nf_layernorm_layer, only: layernorm_layer
1920 use nf_optimizers, only: optimizer_base_type
2021
2122contains
@@ -49,7 +50,6 @@ pure module subroutine backward_1d(self, previous, gradient)
4950 call this_layer % backward(gradient)
5051
5152 type is (flatten_layer)
52-
5353 ! Upstream layers permitted: input2d, input3d, conv1d, conv2d, locally_connected_1d, maxpool1d, maxpool2d
5454 select type (prev_layer = > previous % p)
5555 type is (input2d_layer)
@@ -70,6 +70,8 @@ pure module subroutine backward_1d(self, previous, gradient)
7070 call this_layer % backward(prev_layer % output, gradient)
7171 type is (self_attention_layer)
7272 call this_layer % backward(prev_layer % output, gradient)
73+ type is (layernorm_layer)
74+ call this_layer % backward(prev_layer % output, gradient)
7375 end select
7476
7577 end select
@@ -94,6 +96,8 @@ pure module subroutine backward_2d(self, previous, gradient)
9496 call this_layer % backward(prev_layer % output, gradient)
9597 type is (self_attention_layer)
9698 call this_layer % backward(prev_layer % output, gradient)
99+ type is (layernorm_layer)
100+ call this_layer % backward(prev_layer % output, gradient)
97101 end select
98102
99103 type is (self_attention_layer)
@@ -105,8 +109,18 @@ pure module subroutine backward_2d(self, previous, gradient)
105109 call this_layer % backward(prev_layer % output, gradient)
106110 type is (self_attention_layer)
107111 call this_layer % backward(prev_layer % output, gradient)
112+ type is (layernorm_layer)
113+ call this_layer % backward(prev_layer % output, gradient)
108114 end select
109115
116+ type is (layernorm_layer)
117+
118+ select type (prev_layer = > previous % p)
119+ type is (linear2d_layer)
120+ call this_layer % backward(prev_layer % output, gradient)
121+ type is (self_attention_layer)
122+ call this_layer % backward(prev_layer % output, gradient)
123+ end select
110124 end select
111125
112126 ! Backward pass from a 2-d layer downstream currently implemented
@@ -358,6 +372,8 @@ module subroutine forward(self, input)
358372 call this_layer % forward(prev_layer % output)
359373 type is (linear2d_layer)
360374 call this_layer % forward(prev_layer % output)
375+ type is (layernorm_layer)
376+ call this_layer % forward(prev_layer % output)
361377 end select
362378
363379 type is (reshape3d_layer)
@@ -380,26 +396,40 @@ module subroutine forward(self, input)
380396
381397 type is (linear2d_layer)
382398
383- ! Upstream layers permitted: input2d, linear2d
399+ ! Upstream layers permitted: input2d, linear2d, self_attention, layernorm
384400 select type (prev_layer = > input % p)
385401 type is (input2d_layer)
386402 call this_layer % forward(prev_layer % output)
387403 type is (linear2d_layer)
388404 call this_layer % forward(prev_layer % output)
389405 type is (self_attention_layer)
390406 call this_layer % forward(prev_layer % output)
407+ type is (layernorm_layer)
408+ call this_layer % forward(prev_layer % output)
391409 end select
392410
393411 type is (self_attention_layer)
394412
395- ! Upstream layers permitted: input2d, linear2d
413+ ! Upstream layers permitted: input2d, linear2d, self_attention, layernorm
396414 select type (prev_layer = > input % p)
397415 type is (input2d_layer)
398416 call this_layer % forward(prev_layer % output)
399417 type is (linear2d_layer)
400418 call this_layer % forward(prev_layer % output)
401419 type is (self_attention_layer)
402420 call this_layer % forward(prev_layer % output)
421+ type is (layernorm_layer)
422+ call this_layer % forward(prev_layer % output)
423+ end select
424+
425+ type is (layernorm_layer)
426+
427+ ! Upstream layers permitted: linear2d, self_attention
428+ select type (prev_layer = > input % p)
429+ type is (linear2d_layer)
430+ call this_layer % forward(prev_layer % output)
431+ type is (self_attention_layer)
432+ call this_layer % forward(prev_layer % output)
403433 end select
404434
405435 end select
@@ -449,6 +479,8 @@ pure module subroutine get_output_2d(self, output)
449479 allocate (output, source= this_layer % output)
450480 type is (self_attention_layer)
451481 allocate (output, source= this_layer % output)
482+ type is (layernorm_layer)
483+ allocate (output, source= this_layer % output)
452484 class default
453485 error stop ' 2-d output can only be read from an input2d or linear2d layer.'
454486
@@ -492,8 +524,8 @@ impure elemental module subroutine init(self, input)
492524 call this_layer % init(input % layer_shape)
493525 end select
494526
495- ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or
496- ! self_attention layers is not known until we receive an input layer.
527+ ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d,
528+ ! self_attention or layernorm layers is not known until we receive an input layer.
497529 select type (this_layer = > self % p)
498530 type is (conv1d_layer)
499531 self % layer_shape = shape (this_layer % output)
@@ -511,6 +543,8 @@ impure elemental module subroutine init(self, input)
511543 self % layer_shape = shape (this_layer % output)
512544 type is (self_attention_layer)
513545 self % layer_shape = shape (this_layer % output)
546+ type is (layernorm_layer)
547+ self % layer_shape = shape (this_layer % output)
514548 type is (maxpool2d_layer)
515549 self % layer_shape = shape (this_layer % output)
516550 end select
@@ -577,6 +611,8 @@ elemental module function get_num_params(self) result(num_params)
577611 num_params = this_layer % get_num_params()
578612 type is (self_attention_layer)
579613 num_params = this_layer % get_num_params()
614+ type is (layernorm_layer)
615+ num_params = this_layer % get_num_params()
580616 class default
581617 error stop ' Unknown layer type.'
582618 end select
@@ -618,6 +654,8 @@ module function get_params(self) result(params)
618654 params = this_layer % get_params()
619655 type is (self_attention_layer)
620656 params = this_layer % get_params()
657+ type is (layernorm_layer)
658+ params = this_layer % get_params()
621659 class default
622660 error stop ' Unknown layer type.'
623661 end select
@@ -659,6 +697,8 @@ module function get_gradients(self) result(gradients)
659697 gradients = this_layer % get_gradients()
660698 type is (self_attention_layer)
661699 gradients = this_layer % get_gradients()
700+ type is (layernorm_layer)
701+ gradients = this_layer % get_gradients()
662702 class default
663703 error stop ' Unknown layer type.'
664704 end select
@@ -728,6 +768,9 @@ module subroutine set_params(self, params)
728768 type is (self_attention_layer)
729769 call this_layer % set_params(params)
730770
771+ type is (layernorm_layer)
772+ call this_layer % set_params(params)
773+
731774 type is (maxpool2d_layer)
732775 ! No parameters to set.
733776 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments