
Commit aa59523

multihead_attention_optimization: refactoring, split methods even more (will be needed for llama attention)

1 parent: 7d1a10d

2 files changed: +68 -37 lines

src/nf/nf_multihead_attention.f90 (13 additions, 0 deletions)

```diff
@@ -56,6 +56,8 @@ module nf_multihead_attention_layer
 
     procedure :: common_backward
     procedure :: common_forward
+    procedure :: sdpa_forward
+    procedure :: sdpa_backward
     procedure :: get_num_params
     procedure :: get_params
     procedure :: get_gradients
@@ -102,6 +104,17 @@ pure module subroutine common_forward(self, query, key, value, attention_mask)
       real, optional, intent(in) :: attention_mask(:, :)
     end subroutine common_forward
 
+    pure module subroutine sdpa_forward(self, attention_mask)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in), optional :: attention_mask(:, :)
+    end subroutine sdpa_forward
+
+    pure module subroutine sdpa_backward(self, gradient, attention_mask)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in) :: gradient(:, :)
+      real, intent(in), optional :: attention_mask(:, :)
+    end subroutine sdpa_backward
+
     pure module subroutine init(self, input_shape)
       !! Initialize the layer data structures.
       !!
```
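
Splitting sdpa_forward and sdpa_backward out of common_forward/common_backward gives other attention variants a place to hook in: a derived layer can run its own projections and head handling, then delegate the scaled dot-product attention to the shared procedures, which is the reuse the commit message anticipates for llama attention. Below is a minimal, hypothetical sketch of that pattern; llama_attention_layer and llama_attention_forward are placeholder names, not part of this commit, while the components being called (query_layer, key_layer, value_layer, split_heads, sdpa_forward) are the ones declared above.

```fortran
! Hypothetical sketch only: a derived attention layer reusing the split entry points.
subroutine llama_attention_forward(self, query, key, value, attention_mask)
  class(llama_attention_layer), intent(in out) :: self   ! placeholder type extending multihead_attention_layer
  real, intent(in) :: query(:, :), key(:, :), value(:, :)
  real, intent(in), optional :: attention_mask(:, :)

  ! project inputs with the layer's linear sublayers, as common_forward does
  call self % query_layer % forward(query)
  call self % key_layer % forward(key)
  call self % value_layer % forward(value)

  ! split heads; any llama-specific per-head processing would go here,
  ! before handing off to the shared scaled dot-product attention
  self % q_or_dq = self % split_heads(self % query_layer % output)
  self % k_or_dk = self % split_heads(self % key_layer % output)
  self % v_or_dv = self % split_heads(self % value_layer % output)

  call self % sdpa_forward(attention_mask)
end subroutine llama_attention_forward
```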

src/nf/nf_multihead_attention_submodule.f90 (55 additions, 37 deletions)

```diff
@@ -21,6 +21,60 @@ pure module subroutine common_backward(self, input, gradient, attention_mask)
 
     integer :: head, seq, i, j
 
+    ! backward through attention mechanism
+    call self % sdpa_backward(gradient, attention_mask)
+
+    ! calculate deltas for input layers
+    call self % value_layer % backward(self % v_input, self % combine_heads(self % v_or_dv))
+    call self % key_layer % backward(self % k_input, self % combine_heads(self % k_or_dk))
+    call self % query_layer % backward(self % q_input, self % combine_heads(self % q_or_dq))
+  end subroutine common_backward
+
+  pure module subroutine common_forward(self, query, key, value, attention_mask)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: query(:, :), key(:, :), value(:, :)
+    real, intent(in), optional :: attention_mask(:, :)
+
+    self % q_input = query
+    self % k_input = key
+    self % v_input = value
+
+    ! run inputs through linear layers (trainable params)
+    call self % query_layer % forward(query)
+    call self % key_layer % forward(key)
+    call self % value_layer % forward(value)
+
+    ! split attention heads for more efficient computation
+    self % q_or_dq = self % split_heads(self % query_layer % output)
+    self % k_or_dk = self % split_heads(self % key_layer % output)
+    self % v_or_dv = self % split_heads(self % value_layer % output)
+
+    call self % sdpa_forward(attention_mask)
+  end subroutine common_forward
+
+  pure module subroutine sdpa_forward(self, attention_mask)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in), optional :: attention_mask(:, :)
+
+    ! create key by value matrix
+    call self % create_attention_matrix(self % q_or_dq, self % k_or_dk)
+    ! apply softmax and scaling
+    call self % normalize_attention_matrix(attention_mask)
+    ! multiply attention matrix by value
+    call self % scaled_dot_product_attention(self % v_or_dv)
+
+    self % o_input = self % combine_heads(self % sdpa)
+    call self % output_layer % forward(self % o_input)
+    self % output = self % output_layer % output
+  end subroutine sdpa_forward
+
+  pure module subroutine sdpa_backward(self, gradient, attention_mask)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: gradient(:, :)
+    real, intent(in), optional :: attention_mask(:, :)
+
+    integer :: head, seq, i, j
+
     ! calculate output layer delta
     call self % output_layer % backward(self % o_input, gradient)
 
@@ -78,43 +132,7 @@ pure module subroutine common_backward(self, input, gradient, attention_mask)
       ! calculate delta for key, attention matrix should be transposed unlike for query
       self % k_or_dk(:, :, head) = matmul(transpose(self % d_normalize(:, :, head)), self % q_heads(:, :, head))
     end do
-
-    ! calculate deltas for input layers
-    call self % value_layer % backward(self % v_input, self % combine_heads(self % v_or_dv))
-    call self % key_layer % backward(self % k_input, self % combine_heads(self % k_or_dk))
-    call self % query_layer % backward(self % q_input, self % combine_heads(self % q_or_dq))
-  end subroutine common_backward
-
-  pure module subroutine common_forward(self, query, key, value, attention_mask)
-    class(multihead_attention_layer), intent(in out) :: self
-    real, intent(in) :: query(:, :), key(:, :), value(:, :)
-    real, intent(in), optional :: attention_mask(:, :)
-
-    self % q_input = query
-    self % k_input = key
-    self % v_input = value
-
-    ! run inputs through linear layers (trainable params)
-    call self % query_layer % forward(query)
-    call self % key_layer % forward(key)
-    call self % value_layer % forward(value)
-
-    ! split attention heads for more efficient computation
-    self % q_or_dq = self % split_heads(self % query_layer % output)
-    self % k_or_dk = self % split_heads(self % key_layer % output)
-    self % v_or_dv = self % split_heads(self % value_layer % output)
-
-    ! create key by value matrix
-    call self % create_attention_matrix(self % q_or_dq, self % k_or_dk)
-    ! apply softmax and scaling
-    call self % normalize_attention_matrix(attention_mask)
-    ! multiply attention matrix by value
-    call self % scaled_dot_product_attention(self % v_or_dv)
-
-    self % o_input = self % combine_heads(self % sdpa)
-    call self % output_layer % forward(self % o_input)
-    self % output = self % output_layer % output
-  end subroutine common_forward
+  end subroutine sdpa_backward
 
   pure module function split_heads(self, input) result(output)
     class(multihead_attention_layer), intent(in) :: self
```
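
Taken together, common_forward is now "project, split heads, then sdpa_forward", and common_backward is "sdpa_backward, then input-layer deltas". For reference, the sketch below spells out what the per-head scaled dot-product attention and its gradients compute: scores = Q K^T scaled by 1/sqrt(head_dim), a row-wise softmax, multiplication by V, and the corresponding backward pass (the diff above shows the same transposed-attention pattern for the key delta). This is a standalone illustration with assumed (sequence, head_dim) shapes, it ignores the optional attention_mask, and it is not the repository's implementation, where scaling, masking, and the per-head loop live in the helper routines.

```fortran
module sdpa_sketch
  implicit none
contains

  ! Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
  pure function single_head_sdpa(q, k, v) result(output)
    real, intent(in) :: q(:, :), k(:, :), v(:, :)
    real :: output(size(q, 1), size(v, 2))
    real :: scores(size(q, 1), size(k, 1))
    integer :: i

    ! raw attention scores, scaled by the head dimension
    scores = matmul(q, transpose(k)) / sqrt(real(size(q, 2)))

    ! row-wise softmax over key positions (numerically stabilized)
    do i = 1, size(scores, 1)
      scores(i, :) = exp(scores(i, :) - maxval(scores(i, :)))
      scores(i, :) = scores(i, :) / sum(scores(i, :))
    end do

    ! weight the values by the attention matrix
    output = matmul(scores, v)
  end function single_head_sdpa

  ! Gradients of output = matmul(attn, v) with attn = softmax(Q K^T / sqrt(d)).
  ! attn is the attention matrix saved from the forward pass; d_output is the
  ! incoming gradient; dq, dk, dv receive the results.
  pure subroutine single_head_sdpa_backward(q, k, v, attn, d_output, dq, dk, dv)
    real, intent(in) :: q(:, :), k(:, :), v(:, :)
    real, intent(in) :: attn(:, :), d_output(:, :)
    real, intent(out) :: dq(:, :), dk(:, :), dv(:, :)
    real :: d_attn(size(attn, 1), size(attn, 2))
    real :: d_scores(size(attn, 1), size(attn, 2))
    real :: inv_sqrt_d
    integer :: i

    inv_sqrt_d = 1.0 / sqrt(real(size(q, 2)))

    ! output = matmul(attn, v)  =>  dv = attn^T d_output, d_attn = d_output v^T
    dv = matmul(transpose(attn), d_output)
    d_attn = matmul(d_output, transpose(v))

    ! softmax backward, one query position (row) at a time
    do i = 1, size(attn, 1)
      d_scores(i, :) = attn(i, :) * (d_attn(i, :) - sum(d_attn(i, :) * attn(i, :)))
    end do

    ! scores = Q K^T / sqrt(d)  =>  dq = d_scores K / sqrt(d), dk = d_scores^T Q / sqrt(d)
    dq = matmul(d_scores, k) * inv_sqrt_d
    dk = matmul(transpose(d_scores), q) * inv_sqrt_d
  end subroutine single_head_sdpa_backward

end module sdpa_sketch
```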
