_config.yml  (2 additions, 1 deletion)
@@ -357,6 +357,7 @@ team:
 
   - full_name: Siavash Golkar
     avatar: siavash_golkar.jpg
+    website: http://siavashgolkar.com/
     bio: Siavash Golkar is a research scientist working in the field of machine learning and applications to scientific domains. Siavash received his PhD in 2015 in the field of High Energy Physics, working on topological states of matter. He has since worked as a postdoc at Cambridge University and New York University and as an associate research scientist at the Flatiron Institute. His recent work spans research in ML from continual and transfer learning to applying large transformer models to numerical and scientific datasets.
     bio: Bruno Régaldo-Saint Blancard is a Research Fellow at the Center for Computational Mathematics, Flatiron Institute. He obtained a PhD in Astrophysics from the École Normale Supérieure (ENS), Paris. Prior to that, he graduated from the École Polytechnique, and obtained a M.S. in Astrophysics from the Observatoire de Paris. Bruno’s research focuses on the development of statistical methods for astrophysics/cosmology and beyond, using signal processing and machine learning. He is interested in various problems including generative modeling, inference, denoising, and source separation.
_posts/2024-05-30-counting.md  (26 additions, 15 deletions)
@@ -2,7 +2,7 @@
 layout: post
 title: "How Do Transformers Count in Context?"
 authors: Siavash Golkar, Alberto Bietti, Mariel Pettee, Michael Eickenberg, Miles Cranmer, Keiya Hirashima, Geraud Krawezik, Nicholas Lourie, Michael McCabe, Rudy Morel, Ruben Ohana, Liam Holden Parker, Bruno Régaldo-Saint Blancard, Kyunghyun Cho, Shirley Ho
-shorttitle: "Counting in Context?"
+shorttitle: "Counting in Context"
 date: 2024-05-30 09:23
 image: counting-splash.jpg
 smallimage: counting-s.jpg
@@ -24,13 +24,13 @@ To understand how Transformers solve complex problems, it helps to start with si
 In this task, the input is a sequence composed of zeros, ones, and square bracket delimiters: `{0, 1, [, ]}`. Each sample sequence contains ones and zeros with several regions marked by the delimiters. The task is to count the number of ones within each delimited region. For example, given the sequence:
 
 <p align="center">
-<img src="/images/blog/counting/input.png" alt="Example of input." width="52%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/input.png" alt="Example of input." width="52%" style="mix-blend-mode: darken;">
 </p>
 For simplicity, the regions are non-overlapping.
 
 To tackle this task using a Transformer architecture, we use an encoder-decoder setup and provide the decoder with a fixed prompt that labels the regions. For the example above, the prompt would be:
 <p align="center">
-<img src="/images/blog/counting/prompt.png" alt="Example of prompt." width="52%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/prompt.png" alt="Example of prompt." width="52%" style="mix-blend-mode: darken;">
 </p>
 For our experiments, we fix the number of regions to 4 and the sequence length to 512. This allows us to explore how solutions generalize to different numbers of regions and sequence lengths.
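To make the task in the hunk above concrete, here is a minimal sketch of a data generator for the contextual counting task. It is an illustration only, not the authors' code; the exact way regions are sampled in the paper may differ.

```python
import random

def make_sample(seq_len=512, n_regions=4, seed=0):
    """Generate one contextual-counting sequence and its per-region labels."""
    rng = random.Random(seed)
    seq = [rng.choice(["0", "1"]) for _ in range(seq_len)]
    # Pick 2 * n_regions distinct cut points; consecutive pairs give non-overlapping regions.
    cuts = sorted(rng.sample(range(seq_len), 2 * n_regions))
    for start, end in zip(cuts[0::2], cuts[1::2]):
        seq[start] = "["
        seq[end] = "]"
    # Ground truth: number of 1-tokens strictly inside each delimited region.
    labels = [sum(tok == "1" for tok in seq[s + 1:e])
              for s, e in zip(cuts[0::2], cuts[1::2])]
    return seq, labels

seq, labels = make_sample()
print("".join(seq[:40]), "...", labels)
```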
@@ -44,14 +44,14 @@ We provide some theoretical insights into the problem, showing that a Transforme
 
 Contextual position refers to positional information in a sequence that is meaningful only within the context of the problem. For the contextual counting task, this means knowing the region number for each token. For example, with three regions, the input and contextual positions might look like:
 <p align="center">
-<img src="/images/blog/counting/CP.png" alt="Example of contextual position." width="52%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/CP.png" alt="Example of contextual position." width="52%" style="mix-blend-mode: darken;">
 </p>
 This information helps disambiguate the different regions based on context.
 
 #### Key Propositions
 
-1. **Proposition 1:** If the regional contextual position information is available in the latent representation of the tokens at some layer of a Transformer, the contextual counting task can be solved with a single additional layer.
-2. **Proposition 2:** A causal Transformer with a single layer and no position encoding (NoPE) can infer the regional contextual position.
+- **Proposition 1:** If the regional contextual position information is available in the latent representation of the tokens at some layer of a Transformer, the contextual counting task can be solved with a single additional layer.
+- **Proposition 2:** A causal Transformer with a single layer and no position encoding (NoPE) can infer the regional contextual position.
 
 These propositions imply that a two-layer causal Transformer with NoPE can solve the contextual counting task.
 
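As a concrete illustration of "regional contextual position", here is one possible labeling convention (not necessarily the one used in the post's figure):

```python
def contextual_position(seq):
    """Region number for each token: 0 outside any region, k inside the k-th region."""
    region, inside = 0, False
    labels = []
    for tok in seq:
        if tok == "[":
            region += 1
            inside = True
        labels.append(region if inside else 0)
        if tok == "]":
            inside = False
    return labels

# contextual_position(list("00[101]1[11]"))
# -> [0, 0, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2]
```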
@@ -71,7 +71,7 @@ These propositions highlight the difficulties non-causal Transformers face in so
 The theoretical results above imply that exact solutions exist but do not clarify whether such solutions can indeed be found when the model is trained via SGD. We therefore trained various Transformer architectures on this task. Inspired by the theoretical arguments, we use an encoder-decoder architecture with one layer and one head for each. A typical output of the network is shown in the following image, where the model outputs the probability distribution over the number of ones in each region.
 
 <p align="center">
-<img src="/images/blog/counting/output.png" alt="Typical output of the model" width="55%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/output.png" alt="Typical output of the model" width="55%" style="mix-blend-mode: darken;">
 </p>
 
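The post does not include code, but for orientation, here is a rough PyTorch sketch of a one-layer, one-head encoder-decoder of the kind described above. The hyperparameters, vocabulary, and prompt handling are guesses, not the authors' implementation.

```python
import torch
import torch.nn as nn

class CountingTransformer(nn.Module):
    """Illustrative one-layer, one-head encoder-decoder with a causal, NoPE encoder."""
    def __init__(self, vocab_size=5, n_regions=4, max_count=512, d_model=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)        # tokens: 0, 1, [, ], BoS
        self.region_query = nn.Embedding(n_regions, d_model)  # fixed decoder prompt
        enc = nn.TransformerEncoderLayer(d_model, nhead=1, batch_first=True)
        dec = nn.TransformerDecoderLayer(d_model, nhead=1, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=1)
        self.decoder = nn.TransformerDecoder(dec, num_layers=1)
        self.head = nn.Linear(d_model, max_count + 1)         # distribution over counts

    def forward(self, tokens):                                # tokens: (batch, seq_len)
        causal = nn.Transformer.generate_square_subsequent_mask(tokens.size(1))
        memory = self.encoder(self.embed(tokens), mask=causal)  # no positional encoding added
        prompt = self.region_query.weight.unsqueeze(0).expand(tokens.size(0), -1, -1)
        out = self.decoder(prompt, memory)                    # one query per region
        return self.head(out)                                 # (batch, n_regions, max_count + 1)
```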
@@ -80,7 +80,7 @@ We summarize the results of this empirical exploration below.
-<img src="/images/blog/counting/accuracy.png" alt="Performance of the different configuration" width="65%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/accuracy.png" alt="Performance of the different configuration" width="65%" style="mix-blend-mode: darken;">
 </p>
 
 The above figure shows the performance of different Transformer configurations. The most prominent feature of this figure is that non-causal Transformers with any positional encoding fail to get good performance. In contrast, causal Transformers can achieve close to 100\% accuracy.
@@ -95,7 +95,7 @@ We also see that the very best model is trained with NoPE but RoPE is much more
 As described above, the regional contextual position is an important piece of information for this task. Looking at the projection of the 1-token embeddings in the different regions, we see that this information is accurately captured.
 
 <p align="center">
-<img src="/images/blog/counting/pca_proj.png" alt="PCA projection of the 1-tokens after the encoder layer." width="45%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/pca_proj.png" alt="PCA projection of the 1-tokens after the encoder layer." width="45%" style="mix-blend-mode: darken;">
 </p>
 By looking at the details of the attention module of the encoder, we see that in causal models, this information is inferred by attending to all the previous delimiter tokens equally. Each token can tell which region it is in by looking at how many delimiter tokens of each kind preceded it.
 
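A toy computation (not the authors' probe) shows why uniform attention over the preceding delimiters is enough to identify the region: the fraction of `[` versus `]` tokens seen so far differs from region to region.

```python
import numpy as np

def delimiter_attention_summary(seq):
    """For each position, average a one-hot value over all preceding delimiters,
    mimicking a causal head that attends to '[' and ']' uniformly (illustrative only)."""
    values = {"[": np.array([1.0, 0.0]), "]": np.array([0.0, 1.0])}
    summaries, seen = [], []
    for tok in seq:
        summaries.append(np.mean(seen, axis=0) if seen else np.zeros(2))
        if tok in values:
            seen.append(values[tok])
    return np.array(summaries)

print(delimiter_attention_summary(list("0[101]00[11]1")).round(2))
# 1-tokens in region 1 see ([, ]) fractions (1.00, 0.00); in region 2 they see (0.67, 0.33), etc.
```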
@@ -105,7 +105,7 @@ By looking at the details of the attention module of the encoder, we see that in
 We can verify explicitly that the inferred regional contextual position in the encoder is used in the decoder cross-attention module, such that the attention profile is focused on the 1-tokens of the relevant region (in the figure below, the third region).
 
 <p align="center">
-<img src="/images/blog/counting/decoder.png" alt="The attention profile of the decoder." width="55%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/decoder.png" alt="The attention profile of the decoder." width="55%" style="mix-blend-mode: darken;">
 </p>
 
 We see that in this example, the decoder also attends to the beginning-of-sequence token. The reason is that if the model *only* attends to the 1-tokens, then the number of 1-tokens, the quantity of interest, cancels in the softmax calculation. However, if there is another token in the attention profile, the number of 1-tokens is preserved. In this way, this other token acts as a bias term when computing the output of the attention module.
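A short calculation, paraphrasing the argument above under the assumption of uniform attention weights, makes the role of the bias token explicit:

```latex
% Attending only to the n one-tokens (value vector v_1): n cancels in the softmax average.
\[
  \text{out}_{\text{only 1s}} \;=\; \frac{1}{n}\sum_{i=1}^{n} v_1 \;=\; v_1
\]
% Adding the BoS token (value vector v_{BoS}) to the attention profile preserves n:
\[
  \text{out}_{\text{with BoS}} \;=\; \frac{n\,v_1 + v_{\mathrm{BoS}}}{n+1}
  \;=\; v_1 + \frac{v_{\mathrm{BoS}} - v_1}{n+1}
\]
```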
@@ -115,15 +115,15 @@ We see that in this example, the decoder also attends to the beginning-of-sequen
 The figure below shows the behavior of three different types of solutions when generalizing to sequences of different lengths and inputs with different numbers of regions. Even though all three attain the same performance on the in-distribution data, their out-of-distribution performance is very different. Why is this the case?
 
 <p align="center">
-<img src="/images/blog/counting/var_sols.png" alt="Different types of solutions." width="95%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/var_sols.png" alt="Different types of solutions." width="95%" style="mix-blend-mode: darken;">
 </p>
 
 We can get a hint at what might be the culprit by looking at the attention pattern of the decoder. The attention pattern given in the previous point pertains to the blue dots in this figure, i.e. the model that generalizes best.
 
 The figure below shows the attention pattern of the orange dots, i.e. the model that generalizes to different sequence lengths but not to different numbers of regions. We see that, as before, the decoder pays attention to the 1-tokens of the relevant region (in this case the first region); however, this time the role of the bias term is played by the ]-tokens. During training, the number of regions is fixed at 4, and therefore the number of ]-tokens can be used as a constant bias. However, this is not the case when the number of regions changes. This explains why this model does not generalize to other numbers of regions.
 
 <p align="center">
-<img src="/images/blog/counting/decoder_nongen.png" alt="The attention profile of the decoder of a non-generalizing model." width="55%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/decoder_nongen.png" alt="The attention profile of the decoder of a non-generalizing model." width="55%" style="mix-blend-mode: darken;">
 </p>
 
 In our exploration, we found that the model can use any combination of quantities that are constant during training as biases.
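In the same spirit as the earlier calculation, and again only as a paraphrase assuming roughly uniform attention, the ]-token variant of the bias works only while the number of regions r is fixed:

```latex
% With r regions, a query that attends to the n one-tokens and the r closing brackets gives
\[
  \text{out} \;=\; \frac{n\,v_1 + r\,v_{\,]}}{\,n + r\,}
\]
% During training r = 4 is constant, so this is an invertible function of n,
% but once r varies at test time the same output no longer identifies n uniquely.
```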
@@ -135,13 +135,13 @@ In some of our experiments, we chose to remove the MLP and self-attention layers
 In a previous case we saw that the decoder only attended to the 1-tokens of the relevant region and the beginning-of-sequence token. The figure below shows the value vectors of these two tokens.
 
 <p align="center">
-<img src="/images/blog/counting/values.png" alt="The value vectors." width="55%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/values.png" alt="The value vectors." width="55%" style="mix-blend-mode: darken;">
 </p>
 
 We can verify that by adding n times the value vector of the 1-token to the value vector of the BoS-token, we arrive at a distribution that (after a softmax) is peaked at n. Comparing this with the output of the model, we see that this is indeed what the network is implementing.
 
 <p align="center">
-<img src="/images/blog/counting/formula.png" alt="The value vectors." width="55%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/formula.png" alt="The value vectors." width="55%" style="mix-blend-mode: darken;">
 </p>
 
 Therefore, in these models we fully understand what attention patterns the model is using, how these attention patterns are implemented, and explicitly how the output of the network is constructed.
@@ -150,7 +150,7 @@ If you made it this far, here is an interesting bonus point:
 
 * Even though the model has access to the number n through its attention profile, it still does not construct a probability distribution that is sharply peaked at n. As we see in the figure above, as n gets large, this probability distribution gets wider. We believe this is partly a side effect of this specific solution, where two curves are balanced against each other. But it is also partly a general problem: as the number of tokens that are attended to gets large, higher accuracy is needed to infer n exactly, because the information about n is coded non-linearly after the attention layer. In this case, if we assume that the model attends to the BoS- and 1-tokens equally, the output becomes:
 <p align="center">
-<img src="/images/blog/counting/n_dependence.png" alt="The n-dependence of the model output." width="25%" style="mix-blend-mode: darken;">
+<img class="halfwidth" src="/images/blog/counting/n_dependence.png" alt="The n-dependence of the model output." width="25%" style="mix-blend-mode: darken;">
 </p>
 We see that as n becomes large, the difference between n and n+1 becomes smaller.
 
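A quick numerical check of this last point, using arbitrary toy value vectors rather than the trained model's:

```python
import numpy as np

# Toy value vectors standing in for the 1-token and the BoS-token (not the trained ones).
v1 = np.array([1.0, 0.0])
v_bos = np.array([0.0, 1.0])

def attn_output(n):
    """Uniform attention over n one-tokens plus the BoS token."""
    return (n * v1 + v_bos) / (n + 1)

gaps = [np.linalg.norm(attn_output(n + 1) - attn_output(n)) for n in (1, 10, 100)]
print([f"{g:.4f}" for g in gaps])  # ['0.2357', '0.0107', '0.0001'] -- gaps shrink roughly like 1/n^2
```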
@@ -165,3 +165,14 @@ For more details, check out the [paper](https://arxiv.org/pdf/2406.02585).
 
 
 Image by [Tim Mossholder](https://unsplash.com/photos/blue-and-black-electric-wires-FwzhysPCQZc) via Unsplash.