#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

source("nn/layers/softmax.dml") as softmax
source("nn/layers/dropout.dml") as dropout
source("scripts/nn/util.dml") as util


forward = function(matrix[double] Q, matrix[double] K,
    matrix[double] V, int H, int T, int D, double dropout_p)
  return (matrix[double] context, matrix[double] attention, matrix[double] dropout_mask) {
  /*
   * Computes the forward pass for a multi-head attention layer.
   *
   * Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads):
   *  - Q: Input queries, of shape (B, T*H*D).
   *  - K: Input keys, of shape (B, T*H*D).
   *  - V: Input values, of shape (B, T*H*D).
   *  - H: Head count.
   *  - T: Sequence length.
   *  - D: Embedding length of a single query, key, or value per head.
   *  - dropout_p: Dropout probability.
   *
   * Outputs:
   *  - context: Token context embeddings, of shape (B, T*H*D).
   *  - attention: Attention weights of each query over all keys, of shape (B, H*T*T).
   *  - dropout_mask: Dropout mask applied to the attention weights, of shape (B, H*T*T).
   */
  B = nrow(Q)

  # Transpose head and token dimensions for per-head computation
  Q = util::transpose_ABCD_to_ACBD(Q, T, H)  # Shape (B, H*T*D)
  K = util::transpose_ABCD_to_ACBD(K, T, H)  # Shape (B, H*T*D)
  V = util::transpose_ABCD_to_ACBD(V, T, H)  # Shape (B, H*T*D)

  attention = matrix(0, rows=B, cols=H*T*T)
  dropout_mask = matrix(0, rows=B, cols=H*T*T)
  context = matrix(0, rows=B, cols=H*T*D)
  K_norm = K / sqrt(D)
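  # Note: pre-scaling K by 1/sqrt(D) makes the per-head computation below equal to
  # scaled dot-product attention, softmax(Q_h %*% t(K_h) / sqrt(D)) %*% V_h,
  # without rescaling the (T, T) score matrix inside the loops.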

  # Loop over batch and heads, since the 4D tensors are emulated with 2D matrices
  for (batch in 1:B) {
    attention_probs_b = matrix(0, rows=H, cols=T*T)
    if (dropout_p > 0.0) {
      dropout_mask_b = matrix(0, rows=H, cols=T*T)
    }
    context_b = matrix(0, rows=H, cols=T*D)
    Q_b = matrix(Q[batch], rows=H, cols=T*D)
    K_norm_b = matrix(K_norm[batch], rows=H, cols=T*D)
    V_b = matrix(V[batch], rows=H, cols=T*D)

    for (head in 1:H) {
      Q_h = matrix(Q_b[head], rows=T, cols=D)
      K_norm_h = matrix(K_norm_b[head], rows=T, cols=D)
      V_h = matrix(V_b[head], rows=T, cols=D)

      attention_scores = Q_h %*% t(K_norm_h)  # Shape (T, T)

      # TODO: Add support for attention mask here

      # Softmax over the key dimension (each row of the score matrix is normalized)
      attention_probs_h = softmax::forward(attention_scores)

      if (dropout_p > 0.0) {
        [attention_probs_h, dropout_mask_h] = dropout::forward(attention_probs_h, dropout_p, -1)
      }

      context_h = attention_probs_h %*% V_h  # Shape (T, D)

      attention_probs_b[head] = matrix(attention_probs_h, rows=1, cols=T*T)
      if (dropout_p > 0.0) {
        dropout_mask_b[head] = matrix(dropout_mask_h, rows=1, cols=T*T)
      }
      context_b[head] = matrix(context_h, rows=1, cols=T*D)
    }

    attention[batch] = matrix(attention_probs_b, rows=1, cols=H*T*T)
    if (dropout_p > 0.0) {
      dropout_mask[batch] = matrix(dropout_mask_b, rows=1, cols=H*T*T)
    }
    context[batch] = matrix(context_b, rows=1, cols=H*T*D)
  }

  # Swap head and token dimensions back to restore the original shape
  context = util::transpose_ABCD_to_ACBD(context, H, T)
}
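
# A minimal usage sketch (illustrative only; the source path, sizes, and variable
# names below are assumptions, not part of this layer):
#
#   source("nn/layers/attention.dml") as attention
#   B = 2; T = 4; H = 2; D = 8
#   Q = rand(rows=B, cols=T*H*D)
#   K = rand(rows=B, cols=T*H*D)
#   V = rand(rows=B, cols=T*H*D)
#   [context, probs, mask] = attention::forward(Q, K, V, H, T, D, 0.0)
#   # context: (B, T*H*D); probs and mask: (B, H*T*T)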


backward = function(matrix[double] dcontext,
    matrix[double] dropout_mask, matrix[double] attention, matrix[double] Q,
    matrix[double] K, matrix[double] V, int H, int T,
    int D, double dropout_p)
  return (matrix[double] dQ, matrix[double] dK, matrix[double] dV) {
  /*
   * Computes the backward pass for a multi-head attention layer.
   *
   * Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads):
   *  - dcontext: Gradient w.r.t. the context matrix, of shape (B, T*H*D).
   *  - dropout_mask: Dropout mask from the forward pass, of shape (B, H*T*T).
   *  - attention: Attention output from the forward pass, of shape (B, H*T*T).
   *  - Q: Input queries, of shape (B, T*H*D).
   *  - K: Input keys, of shape (B, T*H*D).
   *  - V: Input values, of shape (B, T*H*D).
   *  - H: Head count.
   *  - T: Sequence length.
   *  - D: Embedding length of a single query, key, or value per head.
   *  - dropout_p: Dropout probability.
   *
   * Outputs:
   *  - dQ: Gradient w.r.t. input queries, of shape (B, T*H*D).
   *  - dK: Gradient w.r.t. input keys, of shape (B, T*H*D).
   *  - dV: Gradient w.r.t. input values, of shape (B, T*H*D).
   */
  B = nrow(Q)

  # Transpose head and token dimensions for per-head computation
  dcontext = util::transpose_ABCD_to_ACBD(dcontext, T, H)  # Shape (B, H*T*D)
  Q = util::transpose_ABCD_to_ACBD(Q, T, H)  # Shape (B, H*T*D)
  K = util::transpose_ABCD_to_ACBD(K, T, H)  # Shape (B, H*T*D)
  V = util::transpose_ABCD_to_ACBD(V, T, H)  # Shape (B, H*T*D)

  dQ = matrix(0, rows=B, cols=H*T*D)  # Shape (B, H*T*D)
  dK = matrix(0, rows=B, cols=H*T*D)  # Shape (B, H*T*D)
  dV = matrix(0, rows=B, cols=H*T*D)  # Shape (B, H*T*D)

  K_norm = K / sqrt(D)

  # Loop over batch and heads, since the 4D tensors are emulated with 2D matrices
  for (batch in 1:B) {
    dcontext_b = matrix(dcontext[batch], rows=H, cols=T*D)
    if (dropout_p > 0.0) {
      dropout_mask_b = matrix(dropout_mask[batch], rows=H, cols=T*T)
    }
    attention_b = matrix(attention[batch], rows=H, cols=T*T)

    Q_b = matrix(Q[batch], rows=H, cols=T*D)
    K_norm_b = matrix(K_norm[batch], rows=H, cols=T*D)
    V_b = matrix(V[batch], rows=H, cols=T*D)

    dQ_b = matrix(0, rows=H, cols=T*D)
    dK_b = matrix(0, rows=H, cols=T*D)
    dV_b = matrix(0, rows=H, cols=T*D)

    for (head in 1:H) {
      dcontext_h = matrix(dcontext_b[head], rows=T, cols=D)
      if (dropout_p > 0.0) {
        dropout_mask_h = matrix(dropout_mask_b[head], rows=T, cols=T)
      }
      attention_h = matrix(attention_b[head], rows=T, cols=T)

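      # Per head, context_h = A_h %*% V_h with A_h = softmax(Q_h %*% t(K_norm_h)), so
      # dV_h = t(A_h) %*% dcontext_h and dA_h = dcontext_h %*% t(V_h); chaining dA_h
      # through dropout and softmax backward yields dQ_h and dK_h below.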
      # Compute dV early to release attention_h
      dV_h = t(attention_h) %*% dcontext_h

      Q_h = matrix(Q_b[head], rows=T, cols=D)
      K_norm_h = matrix(K_norm_b[head], rows=T, cols=D)
      V_h = matrix(V_b[head], rows=T, cols=D)

      dattention_probs = dcontext_h %*% t(V_h)

      if (dropout_p > 0.0) {
        # dropout::backward requires an X argument that is unused in the computation,
        # so an empty placeholder matrix is passed
        dattention_probs = dropout::backward(dattention_probs, matrix(0, rows=1, cols=1), dropout_p, dropout_mask_h)
      }
      attention_scores = Q_h %*% t(K_norm_h)  # Shape (T, T)
      dattention_scores = softmax::backward(dattention_probs, attention_scores)

      dQ_h = dattention_scores %*% K_norm_h
      dK_h = t(dattention_scores) %*% (Q_h / sqrt(D))

      # Store per-head results in the batch matrices
      dK_b[head] = matrix(dK_h, rows=1, cols=T*D)
      dQ_b[head] = matrix(dQ_h, rows=1, cols=T*D)
      dV_b[head] = matrix(dV_h, rows=1, cols=T*D)
    }

    # Store per-batch results in the output matrices
    dK[batch] = matrix(dK_b, rows=1, cols=H*T*D)
    dQ[batch] = matrix(dQ_b, rows=1, cols=H*T*D)
    dV[batch] = matrix(dV_b, rows=1, cols=H*T*D)
  }

  # Swap head and token dimensions back to restore the original shape
  dK = util::transpose_ABCD_to_ACBD(dK, H, T)
  dQ = util::transpose_ABCD_to_ACBD(dQ, H, T)
  dV = util::transpose_ABCD_to_ACBD(dV, H, T)
}
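
# A minimal backward sketch continuing the usage example after forward() above
# (illustrative only; the all-ones upstream gradient is an arbitrary assumption):
#
#   dcontext = matrix(1, rows=B, cols=T*H*D)
#   [dQ, dK, dV] = attention::backward(dcontext, mask, probs, Q, K, V, H, T, D, 0.0)
#   # dQ, dK, dV: (B, T*H*D), matching the shapes of Q, K, V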