\titledquestion{Automatic Differentiation and Attention}

Consider the matrices
$V \in \mathbb{R}^{d_v \times n_\text{ctx}},
K \in \mathbb{R}^{d_k \times n_\text{ctx}}$
and an attention head applied to a vector $q$: $a(q) = V\text{softmax}(K^\top q / \sqrt{d_k})$,
where the $i$-th entry of $\text{softmax}(x)$ is $\exp(x_i) / \sum_j \exp(x_j)$.
Suppose that you want to compute the Jacobian $L$ of $a(q)$, that is, the derivative of all outputs of the attention
head with respect to each entry of the vector $q$.
In other words, $L_{ij} = \partial a_i / \partial q_j$.

\begin{itemize}
    \item Find the matrix $J$ to be used to compute
        the derivative of $a$ with respect to the $j$-th entry of $q$
        with the formula
        $$\partial a / \partial q_j = V J K^\top e_j / \sqrt{d_k}.$$
        \emph{Hint:} The entries of the matrix $J$ can be written purely in terms
        of the entries of the vector $s = \text{softmax}(K^\top q / \sqrt{d_k})$.
        \begin{solutionbox}{9cm}
            Let $x = K^\top q / \sqrt{d_k}$. We have
            \begin{align*}
                J_{ij} & =
                \partial s_i / \partial x_j\\
                & =
                \frac{\partial}{\partial x_j} \frac{\exp(x_i)}{\sum_{k=1}^{n_{\text{ctx}}} \exp(x_k)}\\
                & =
                \frac{\exp(x_i)}{\sum_{k=1}^{n_{\text{ctx}}} \exp(x_k)}
                \frac{\partial x_i}{\partial x_j}
                -
                \frac{\exp(x_i)}{\left(\sum_{k=1}^{n_{\text{ctx}}} \exp(x_k)\right)^2}
                \frac{\partial}{\partial x_j} \sum_{k=1}^{n_{\text{ctx}}} \exp(x_k)\\
                & =
                s_i
                \frac{\partial x_i}{\partial x_j}
                - s_is_j.
            \end{align*}
            Since $\partial x_i / \partial x_j = \delta_{ij}$, the entries of $J$ are:
            \begin{align*}
                J_{ii} & = s_i - s_i^2\\
                J_{ij} & = -s_is_j & \quad i \neq j.
            \end{align*}
            In matrix form, $J = \operatorname{diag}(s) - ss^\top$.
        \end{solutionbox}
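As a quick sanity check outside the exam itself, the closed form $J = \operatorname{diag}(s) - ss^\top$ can be compared against central finite differences of the softmax (a NumPy sketch; all names below are ours):

```python
import numpy as np

def softmax(x):
    # subtract the max for numerical stability
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
n = 5
x = rng.standard_normal(n)
s = softmax(x)

# closed-form Jacobian: J = diag(s) - s s^T
J = np.diag(s) - np.outer(s, s)

# finite-difference approximation, one column per input coordinate
eps = 1e-6
J_fd = np.empty((n, n))
for j in range(n):
    xp, xm = x.copy(), x.copy()
    xp[j] += eps
    xm[j] -= eps
    J_fd[:, j] = (softmax(xp) - softmax(xm)) / (2 * eps)

assert np.allclose(J, J_fd, atol=1e-7)
```

Note that $J$ is symmetric and its columns sum to zero, reflecting the fact that the softmax outputs always sum to one.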
    \item The previous question corresponds to \emph{forward} differentiation.
        Using \emph{reverse} differentiation, you would now like to
        compute the gradient of $a_i$ with respect to the vector $q$.
        How can you compute this gradient vector via matrix-vector products?

        \emph{Hint:} $\partial a_i / \partial q_j$ is the scalar product between
        $\partial a / \partial q_j$ and $e_i$.
        \begin{solutionbox}{6cm}
            Since $J = \operatorname{diag}(s) - ss^\top$ is symmetric, $J^\top = J$.
            \begin{align*}
                \partial a_i / \partial q_j
                & =
                \langle V J K^\top e_j / \sqrt{d_k}, e_i \rangle\\
                & =
                \langle e_j, K J^\top V^\top e_i / \sqrt{d_k} \rangle\\
                \partial a_i / \partial q
                & =
                K J V^\top e_i / \sqrt{d_k}
            \end{align*}
        \end{solutionbox}
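The reverse-mode formula $\partial a_i / \partial q = K J V^\top e_i / \sqrt{d_k}$ can likewise be checked numerically against finite differences of $a$ (a NumPy sketch with made-up dimensions, not part of the exam):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(1)
d_v, d_k, n_ctx = 3, 4, 6
V = rng.standard_normal((d_v, n_ctx))
K = rng.standard_normal((d_k, n_ctx))
q = rng.standard_normal(d_k)

def a(q):
    # attention head applied to the query vector q
    return V @ softmax(K.T @ q / np.sqrt(d_k))

s = softmax(K.T @ q / np.sqrt(d_k))
J = np.diag(s) - np.outer(s, s)

# reverse-mode formula: gradient of a_i w.r.t. q is K J V^T e_i / sqrt(d_k);
# V^T e_i is simply the i-th row of V
i = 0
grad_i = K @ (J @ V[i]) / np.sqrt(d_k)

# central finite differences of a_i along each coordinate of q
eps = 1e-6
E = np.eye(d_k)
fd = np.array([(a(q + eps * E[j])[i] - a(q - eps * E[j])[i]) / (2 * eps)
               for j in range(d_k)])

assert np.allclose(grad_i, fd, atol=1e-6)
```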
    \item
        Depending on $d_v, d_k, n_\text{ctx}$, which of forward and reverse differentiation
        will be faster for computing the \textbf{full} Jacobian matrix $\partial a / \partial q$? Why?
        \begin{solutionbox}{4cm}
            Forward diff concatenates $\partial a / \partial q_j$ horizontally for each $j$ and
            reverse diff concatenates $\partial a_i / \partial q$
            vertically for each $i$.
            In other words, forward diff computes $V(JK^\top)$ while
            reverse diff computes $K(J^\top V^\top)$ or equivalently
            $(VJ)K^\top$.
            The complexity of forward diff is $O(n_\text{ctx}^2 d_k + n_\text{ctx} d_k d_v)$
            while the complexity of reverse diff is
            $O(n_\text{ctx}^2 d_v + n_\text{ctx} d_k d_v)$.
            This means that forward diff is faster if $d_k < d_v$, otherwise reverse diff is faster.
        \end{solutionbox}
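The two modes differ only in how the triple product is associated, so they yield the same Jacobian; only the flop counts differ. A small NumPy illustration (dimensions chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(2)
d_v, d_k, n_ctx = 2, 8, 5
V = rng.standard_normal((d_v, n_ctx))
K = rng.standard_normal((d_k, n_ctx))

# any softmax output works here; build one directly
s = rng.random(n_ctx)
s /= s.sum()
J = np.diag(s) - np.outer(s, s)

# forward mode: V (J K^T), cost O(n_ctx^2 d_k + n_ctx d_k d_v)
forward = V @ (J @ K.T) / np.sqrt(d_k)
# reverse mode: (V J) K^T, cost O(n_ctx^2 d_v + n_ctx d_k d_v)
reverse = (V @ J) @ K.T / np.sqrt(d_k)

# same Jacobian either way; only the arithmetic cost differs
assert np.allclose(forward, reverse)
```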
    \item How could this computation be accelerated using a GPU instead of a CPU?
        \begin{solutionbox}{2cm}
            As these are matrix-matrix products, the computation is highly parallelizable
            and hence will see a substantial speed-up on a GPU.
        \end{solutionbox}
\end{itemize}