
Commit 0a399cf

multihead_attention_optimization: move heads separation out of sdpa backward
1 parent a076246

1 file changed: 4 additions, 3 deletions

src/nf/nf_multihead_attention_submodule.f90

Lines changed: 4 additions & 3 deletions
@@ -21,6 +21,10 @@ pure module subroutine common_backward(self, input, gradient, attention_mask)
 
     integer :: head, seq, i, j
 
+    self % v_heads = self % split_heads(self % value_layer % output)
+    self % k_heads = self % split_heads(self % key_layer % output)
+    self % q_heads = self % split_heads(self % query_layer % output)
+
     ! bakward through attention mechanism
     call self % sdpa_backward(gradient, attention_mask)
 
@@ -80,9 +84,6 @@ pure module subroutine sdpa_backward(self, gradient, attention_mask)
 
     ! split heads from output gradient
     self % d_output = self % split_heads(self % output_layer % gradient)
-    self % v_heads = self % split_heads(self % value_layer % output)
-    self % k_heads = self % split_heads(self % key_layer % output)
-    self % q_heads = self % split_heads(self % query_layer % output)
 
     ! iterate over heads to calculate deltas for each of them
    do concurrent(head = 1: self % n_heads)
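In effect, the v/k/q head splitting now happens once in common_backward, before sdpa_backward is called, and sdpa_backward is left to split only the output gradient. Below is a minimal, self-contained sketch of the kind of reshape a split_heads-style routine performs; the (sequence_length, head_size, n_heads) layout and every name in it are assumptions made for illustration, not the nf_multihead_attention API.

! Hedged sketch: a standalone approximation of a split_heads-style reshape,
! assuming activations are stored as (sequence_length, model_dimension) and
! split into (sequence_length, head_size, n_heads). Names and shapes are
! illustrative only; this is not the library routine.
program split_heads_sketch
  implicit none
  integer, parameter :: sequence_length = 4
  integer, parameter :: model_dimension = 6
  integer, parameter :: n_heads = 2
  integer, parameter :: head_size = model_dimension / n_heads
  real :: flat(sequence_length, model_dimension)
  real :: heads(sequence_length, head_size, n_heads)

  call random_number(flat)

  ! One reshape yields all per-head slices at once; doing it a single time per
  ! backward pass (as in common_backward above) means sdpa_backward does not
  ! have to repeat it.
  heads = reshape(flat, [sequence_length, head_size, n_heads])

  print '(a,3(1x,i0))', 'heads shape:', shape(heads)
end program split_heads_sketch

Compiling and running the sketch (e.g. gfortran split_heads_sketch.f90 && ./a.out) prints "heads shape: 4 3 2". Judging from the commit title alone, the point of hoisting these calls is that sdpa_backward no longer recomputes the q/k/v head splits on each invocation; that reading comes from the word "optimization" in the title, not from anything else in the diff.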
