@@ -21,6 +21,60 @@ pure module subroutine common_backward(self, input, gradient, attention_mask)

     integer :: head, seq, i, j

+    ! backward through attention mechanism
+    call self % sdpa_backward(gradient, attention_mask)
+
+    ! calculate deltas for input layers
+    call self % value_layer % backward(self % v_input, self % combine_heads(self % v_or_dv))
+    call self % key_layer % backward(self % k_input, self % combine_heads(self % k_or_dk))
+    call self % query_layer % backward(self % q_input, self % combine_heads(self % q_or_dq))
+  end subroutine common_backward
+
+  pure module subroutine common_forward(self, query, key, value, attention_mask)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: query(:, :), key(:, :), value(:, :)
+    real, intent(in), optional :: attention_mask(:, :)
+
+    self % q_input = query
+    self % k_input = key
+    self % v_input = value
+
+    ! run inputs through linear layers (trainable params)
+    call self % query_layer % forward(query)
+    call self % key_layer % forward(key)
+    call self % value_layer % forward(value)
+
+    ! split attention heads for more efficient computation
+    self % q_or_dq = self % split_heads(self % query_layer % output)
+    self % k_or_dk = self % split_heads(self % key_layer % output)
+    self % v_or_dv = self % split_heads(self % value_layer % output)
+
+    call self % sdpa_forward(attention_mask)
+  end subroutine common_forward
+
+  pure module subroutine sdpa_forward(self, attention_mask)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in), optional :: attention_mask(:, :)
+
+    ! create query by key attention matrix
+    call self % create_attention_matrix(self % q_or_dq, self % k_or_dk)
+    ! apply softmax and scaling
+    call self % normalize_attention_matrix(attention_mask)
+    ! multiply attention matrix by value
+    call self % scaled_dot_product_attention(self % v_or_dv)
+
+    self % o_input = self % combine_heads(self % sdpa)
+    call self % output_layer % forward(self % o_input)
+    self % output = self % output_layer % output
+  end subroutine sdpa_forward
+
+  pure module subroutine sdpa_backward(self, gradient, attention_mask)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: gradient(:, :)
+    real, intent(in), optional :: attention_mask(:, :)
+
+    integer :: head, seq, i, j
+
     ! calculate output layer delta
     call self % output_layer % backward(self % o_input, gradient)

@@ -78,43 +132,7 @@ pure module subroutine common_backward(self, input, gradient, attention_mask)
       ! calculate delta for key, attention matrix should be transposed unlike for query
       self % k_or_dk(:, :, head) = matmul(transpose(self % d_normalize(:, :, head)), self % q_heads(:, :, head))
     end do
-
-    ! calculate deltas for input layers
-    call self % value_layer % backward(self % v_input, self % combine_heads(self % v_or_dv))
-    call self % key_layer % backward(self % k_input, self % combine_heads(self % k_or_dk))
-    call self % query_layer % backward(self % q_input, self % combine_heads(self % q_or_dq))
-  end subroutine common_backward
-
-  pure module subroutine common_forward(self, query, key, value, attention_mask)
-    class(multihead_attention_layer), intent(in out) :: self
-    real, intent(in) :: query(:, :), key(:, :), value(:, :)
-    real, intent(in), optional :: attention_mask(:, :)
-
-    self % q_input = query
-    self % k_input = key
-    self % v_input = value
-
-    ! run inputs through linear layers (trainable params)
-    call self % query_layer % forward(query)
-    call self % key_layer % forward(key)
-    call self % value_layer % forward(value)
-
-    ! split attention heads for more efficient computation
-    self % q_or_dq = self % split_heads(self % query_layer % output)
-    self % k_or_dk = self % split_heads(self % key_layer % output)
-    self % v_or_dv = self % split_heads(self % value_layer % output)
-
-    ! create key by value matrix
-    call self % create_attention_matrix(self % q_or_dq, self % k_or_dk)
-    ! apply softmax and scaling
-    call self % normalize_attention_matrix(attention_mask)
-    ! multiply attention matrix by value
-    call self % scaled_dot_product_attention(self % v_or_dv)
-
-    self % o_input = self % combine_heads(self % sdpa)
-    call self % output_layer % forward(self % o_input)
-    self % output = self % output_layer % output
-  end subroutine common_forward
+  end subroutine sdpa_backward

   pure module function split_heads(self, input) result(output)
     class(multihead_attention_layer), intent(in) :: self
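
For reference, the extracted sdpa_forward above chains create_attention_matrix, normalize_attention_matrix, and scaled_dot_product_attention, i.e. softmax(Q K^T / sqrt(d)) V per head, before the output projection. Below is a minimal single-head sketch of that math in plain Fortran; it is not the library's batched multi-head implementation, the attention mask is omitted, and the names (sdpa_sketch, seq_len, d_model) are illustrative assumptions only.

program sdpa_sketch
  implicit none
  integer, parameter :: seq_len = 3, d_model = 4   ! assumed toy sizes
  real :: q(seq_len, d_model), k(seq_len, d_model), v(seq_len, d_model)
  real :: scores(seq_len, seq_len), attn(seq_len, seq_len), out(seq_len, d_model)
  integer :: i

  call random_number(q)
  call random_number(k)
  call random_number(v)

  ! attention matrix: Q K^T, scaled by the square root of the head dimension
  scores = matmul(q, transpose(k)) / sqrt(real(d_model))

  ! row-wise softmax (the role of normalize_attention_matrix, per head)
  do i = 1, seq_len
    attn(i, :) = exp(scores(i, :) - maxval(scores(i, :)))
    attn(i, :) = attn(i, :) / sum(attn(i, :))
  end do

  ! weighted sum of values (the role of scaled_dot_product_attention)
  out = matmul(attn, v)
  print '(4f8.4)', out(1, :)
end program sdpa_sketch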