
Commit 5832d82

Merge branch 'main' of github.com:PolymathicAI/PolymathicAI.github.io
2 parents: 6ccf1de + b490ee6

10 files changed: +119, -16 lines

_config.yml

Lines changed: 2 additions & 1 deletion

@@ -357,6 +357,7 @@ team:

  - full_name: Siavash Golkar
    avatar: siavash_golkar.jpg
+    website: http://siavashgolkar.com/
    bio: Siavash Golkar is a research scientist working in the field of machine learning and applications to scientific domains. Siavash received his PhD in 2015 in the field of High Energy Physics, working on topological states of matter. He has since worked as a postdoc at Cambridge University and New York University and as an associate research scientist at the Flatiron Institute. His recent work spans research in ML from continual and transfer learning to applying large transformer models to numerical and scientific datasets.

  - full_name: Keiya Hirashima

@@ -411,7 +412,7 @@ team:

  - full_name: Bruno Regaldo
    avatar: bruno_regaldo_saint_blancard.jpg
-    website: https://users.flatironinstitute.org/~bregaldosaintblancard/
+    website: https://bregaldo.github.io/
    bio: Bruno Régaldo-Saint Blancard is a Research Fellow at the Center for Computational Mathematics, Flatiron Institute. He obtained a PhD in Astrophysics from the École Normale Supérieure (ENS), Paris. Prior to that, he graduated from the École Polytechnique, and obtained a M.S. in Astrophysics from the Observatoire de Paris. Bruno’s research focuses on the development of statistical methods for astrophysics/cosmology and beyond, using signal processing and machine learning. He is interested in various problems including generative modeling, inference, denoising, and source separation.

_posts/2024-05-30-counting.md

Lines changed: 26 additions & 15 deletions
@@ -2,7 +2,7 @@
layout: post
title: "How Do Transformers Count in Context?"
authors: Siavash Golkar, Alberto Bietti, Mariel Pettee, Michael Eickenberg, Miles Cranmer, Keiya Hirashima, Geraud Krawezik, Nicholas Lourie, Michael McCabe, Rudy Morel, Ruben Ohana, Liam Holden Parker, Bruno Régaldo-Saint Blancard, Kyunghyun Cho, Shirley Ho
-shorttitle: "Counting in Context?"
+shorttitle: "Counting in Context"
date: 2024-05-30 09:23
image: counting-splash.jpg
smallimage: counting-s.jpg
@@ -24,13 +24,13 @@ To understand how Transformers solve complex problems, it helps to start with si
In this task, the input is a sequence composed of zeros, ones, and square bracket delimiters: `{0, 1, [, ]}`. Each sample sequence contains ones and zeros with several regions marked by the delimiters. The task is to count the number of ones within each delimited region. For example, given the sequence:

<p align="center">
-<img src="/images/blog/counting/input.png" alt="Example of input." width="52%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/input.png" alt="Example of input." width="52%" style="mix-blend-mode: darken;">
</p>
For simplicity, the regions are non-overlapping.

To tackle this task using a Transformer architecture, we use an encoder-decoder setup and provide the decoder with a fixed prompt that labels the regions. For the example above, the prompt would be:
<p align="center">
-<img src="/images/blog/counting/prompt.png" alt="Example of prompt." width="52%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/prompt.png" alt="Example of prompt." width="52%" style="mix-blend-mode: darken;">
</p>
For our experiments, we fix the number of regions to 4 and the sequence length to 512. This allows us to explore how solutions generalize to different numbers of regions and sequence lengths.
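
As a concrete illustration of the task set up in the hunk above, here is a minimal sketch of how one such sample could be generated. This is my own illustration rather than code from the paper or this repository, and the tokenization and boundary-sampling choices are assumptions.

```python
import random

def make_sample(seq_len=512, n_regions=4, seed=0):
    """Generate one contextual-counting sample: a token sequence over
    {'0', '1', '[', ']'} with n_regions non-overlapping delimited regions,
    plus the ground-truth count of 1-tokens per region. The way boundaries
    are drawn here is an assumption, not the paper's exact recipe."""
    rng = random.Random(seed)
    boundaries = set(rng.sample(range(seq_len), 2 * n_regions))
    seq, counts = [], [0] * n_regions
    inside, region = False, -1
    for i in range(seq_len):
        if i in boundaries:
            if not inside:
                region += 1
                seq.append("[")
            else:
                seq.append("]")
            inside = not inside
        else:
            tok = rng.choice("01")
            seq.append(tok)
            if inside and tok == "1":
                counts[region] += 1
    return seq, counts

seq, counts = make_sample()
print("".join(seq[:60]))
print(counts)  # one integer per region
```

Running it yields a 512-token string over `0`, `1`, `[`, `]` and a list of four ground-truth counts, one per region.
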
@@ -44,14 +44,14 @@ We provide some theoretical insights into the problem, showing that a Transforme

Contextual position refers to positional information in a sequence that is meaningful only within the context of the problem. For the contextual counting task, this means knowing the region number for each token. For example, with three regions, the input and contextual positions might look like:
<p align="center">
-<img src="/images/blog/counting/CP.png" alt="Example of contextual position." width="52%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/CP.png" alt="Example of contextual position." width="52%" style="mix-blend-mode: darken;">
</p>
This information helps disambiguate the different regions based on context.

#### Key Propositions

-1. **Proposition 1:** If the regional contextual position information is available in the latent representation of the tokens at some layer of a Transformer, the contextual counting task can be solved with a single additional layer.
-2. **Proposition 2:** A causal Transformer with a single layer and no position encoding (NoPE) can infer the regional contextual position.
+- **Proposition 1:** If the regional contextual position information is available in the latent representation of the tokens at some layer of a Transformer, the contextual counting task can be solved with a single additional layer.
+- **Proposition 2:** A causal Transformer with a single layer and no position encoding (NoPE) can infer the regional contextual position.

These propositions imply that a two-layer causal Transformer with NoPE can solve the contextual counting task.
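
To make the notion of regional contextual position concrete, the sketch below assigns each token its region index using only the delimiters that precede it, which is the information Proposition 2 says a one-layer causal Transformer with NoPE can recover. The labeling convention (delimiters counted as part of their region, 0 for tokens outside any region) is my own choice for the illustration.

```python
def contextual_positions(seq):
    """Assign each token its regional contextual position: 0 outside any
    region, k inside the k-th region (1-indexed). Only the delimiters seen
    so far are used, mirroring what a causal model without positional
    encodings could infer."""
    opens = closes = 0
    labels = []
    for tok in seq:
        if tok == "[":
            opens += 1
            labels.append(opens)
        elif tok == "]":
            labels.append(opens)
            closes += 1
        else:
            labels.append(opens if opens > closes else 0)
    return labels

print(contextual_positions(list("01[101]0[11]")))
# -> [0, 0, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2]
```
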
@@ -71,7 +71,7 @@ These propositions highlight the difficulties non-causal Transformers face in so
The theoretical results above imply that exact solutions exist but do not clarify whether or not such solutions can indeed be found when the model is trained via SGD. We therefore trained various Transformer architectures on this task. Inspired by the theoretical arguments, we use an encoder-decoder architecture, with one layer and one head for each. A typical output of the network is shown in the following image where the model outputs the probability distribution over the number of ones in each region.

<p align="center">
-<img src="/images/blog/counting/output.png" alt="Typical output of the model" width="55%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/output.png" alt="Typical output of the model" width="55%" style="mix-blend-mode: darken;">
</p>
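
For orientation, here is a hedged sketch of such a one-layer, one-head encoder-decoder in PyTorch. It is an assumption on my part rather than the authors' implementation: vocabulary size, model width, and the maximum count are placeholders, and the causal versus non-causal variants discussed below would differ only in the attention masks passed in.

```python
import torch
import torch.nn as nn

class ContextualCountingModel(nn.Module):
    """Hedged sketch of a one-layer, one-head encoder-decoder for the task.
    The decoder reads a fixed prompt labeling the regions and outputs, for
    each region, a distribution over possible counts."""
    def __init__(self, vocab_size=8, d_model=64, max_count=512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=1,
            num_encoder_layers=1,
            num_decoder_layers=1,
            dim_feedforward=4 * d_model,
            batch_first=True,
        )
        self.head = nn.Linear(d_model, max_count + 1)

    def forward(self, src_tokens, prompt_tokens, src_mask=None, tgt_mask=None):
        src = self.embed(src_tokens)      # (batch, seq_len, d_model)
        tgt = self.embed(prompt_tokens)   # (batch, n_regions, d_model)
        out = self.transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        return self.head(out)             # count logits, one row per region
```
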
@@ -80,7 +80,7 @@ We summarize the results of this empirical exploration below.
#### 1. Causal Transformers significantly outperform non-causal ones.

<p align="center">
-<img src="/images/blog/counting/accuracy.png" alt="Performance of the different configuration" width="65%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/accuracy.png" alt="Performance of the different configuration" width="65%" style="mix-blend-mode: darken;">
</p>

The above figure shows the performance of different Transformer configurations. The most prominent feature of this figure is that non-causal Transformers with any positional encoding fail to get good performance. In contrast, causal Transformers can achieve close to 100\% accuracy.
@@ -95,7 +95,7 @@ We also see that the very best model is trained with NoPE but RoPE is much more
As described above, the regional contextual position is an important piece of information for this task. Looking at the projection of the 1-token embeddings in the different regions, we see that this information is accurately captured.

<p align="center">
-<img src="/images/blog/counting/pca_proj.png" alt="PCA projection of the 1-tokens after the encoder layer." width="45%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/pca_proj.png" alt="PCA projection of the 1-tokens after the encoder layer." width="45%" style="mix-blend-mode: darken;">
</p>
By looking at the details of the attention module of the encoder, we see that in causal models, this information is inferred by attending to all the previous delimiter tokens equally. Each token can tell which region it is in by looking at how many delimiter tokens of each kind preceded it.
@@ -105,7 +105,7 @@ By looking at the details of the attention module of the encoder, we see that in
We can verify explicitly that the inferred regional contextual position in the encoder is used in the decoder cross-attention module such that the attention profile is focused on the 1-tokens of the relevant region (in the below figure, the third region).

<p align="center">
-<img src="/images/blog/counting/decoder.png" alt="The attention profile of the decoder." width="55%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/decoder.png" alt="The attention profile of the decoder." width="55%" style="mix-blend-mode: darken;">
</p>

We see that in this example, the decoder also attends to the beginning-of-sequence token. The reason for this is that, if the model *only* attends to the 1-tokens, then the number of the 1-tokens - the quantity of interest - is going to cancel in the calculation of the softmax. However, if there is another token, then the number of 1-tokens will be preserved. In this way, this other token acts as a bias term when computing the output of the attention module.
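
This cancellation is easy to check numerically. In the sketch below (with made-up value vectors, not the trained model's weights), uniform attention over n identical 1-token values returns the same output for every n, while including one BoS token as a bias restores the dependence on n.

```python
import numpy as np

v_one = np.array([1.0, 0.0])   # made-up value vector for a 1-token
v_bos = np.array([0.0, 1.0])   # made-up value vector for the BoS token

def attention_output(n, include_bos):
    """Equal attention scores -> softmax gives uniform weights over the
    attended tokens; return the resulting weighted average of the values."""
    values = [v_one] * n + ([v_bos] if include_bos else [])
    weights = np.full(len(values), 1.0 / len(values))
    return weights @ np.stack(values)

for n in (2, 5, 10):
    print(n, attention_output(n, False), attention_output(n, True))
# Without BoS the output is v_one for every n (the count cancels);
# with BoS it equals (n*v_one + v_bos) / (n + 1), which retains n.
```
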
@@ -115,15 +115,15 @@ We see that in this example, the decoder also attends to the beginning-of-sequen
The figure below shows the behavior of three different types of solutions when generalizing to sequences of different lengths and inputs with different numbers of regions. Even though all three attain the same performance on the in-distribution data, their out-of-distribution performance is very different. Why is this the case?

<p align="center">
-<img src="/images/blog/counting/var_sols.png" alt="Different types of solutions." width="95%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/var_sols.png" alt="Different types of solutions." width="95%" style="mix-blend-mode: darken;">
</p>

We can get a hint at what might be the culprit by looking at the attention pattern of the decoder. The attention pattern given in the previous point pertains to the blue dots on this figure, i.e. the model that generalizes best.

The figure below shows the attention pattern of the orange dots, i.e. the model that generalizes to different sequence lengths but not to different numbers of regions. We see that, as before, the decoder pays attention to the 1-tokens of the relevant region (in this case the first region); however, this time the role of the bias term is played by the ]-tokens. During training, the number of regions is fixed at 4, and therefore the number of ]-tokens can be used as a constant bias. However, this is not the case when the number of regions changes. This explains why this model does not generalize to other numbers of regions.

<p align="center">
-<img src="/images/blog/counting/decoder_nongen.png" alt="The attention profile of the decoder of a non-generalizing model." width="55%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/decoder_nongen.png" alt="The attention profile of the decoder of a non-generalizing model." width="55%" style="mix-blend-mode: darken;">
</p>

In our exploration, we found that the model can use any combination of quantities that are constant during training as biases.
@@ -135,13 +135,13 @@ In some of our experiments, we chose to remove the MLP and self-attention layers
In a previous case we saw that the decoder only attended to the 1-tokens of the relevant region and the beginning-of-sequence token. The figure below shows the value vectors of these two tokens.

<p align="center">
-<img src="/images/blog/counting/values.png" alt="The value vectors." width="55%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/values.png" alt="The value vectors." width="55%" style="mix-blend-mode: darken;">
</p>

We can verify that by adding n-times the value vector of the 1-token to the value vector of the BoS-token, we arrive at a distribution that (after a softmax) is peaked at n. Comparing this with the output of the model, we see that this is indeed what the network is implementing.

<p align="center">
-<img src="/images/blog/counting/formula.png" alt="The value vectors." width="55%" style="mix-blend-mode: darken;">
+<img class="fullwidth" src="/images/blog/counting/formula.png" alt="The value vectors." width="55%" style="mix-blend-mode: darken;">
</p>

Therefore, in these models we fully understand what attention patterns the model is using, how these attention patterns are implemented and explicitly how the output of the network is constructed.
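
As one explicit toy construction of how a BoS value plus n times the 1-token value can yield a distribution peaked at n: choose readout logits of the form k·n - k²/2, which are maximized at k = n. This is my own illustration of the additive structure, not the weights the trained network learns, and it ignores the 1/(n+1) attention normalization that the next point turns on.

```python
import numpy as np

max_count = 20
ks = np.arange(max_count + 1)

# Toy readout logits: logit_k(n) = k*n - k**2/2, maximized at k = n.
# logits_bos plays the role of the BoS contribution (a fixed "bias" curve)
# and logits_one the per-1-token contribution; both are made-up numbers.
logits_bos = -0.5 * ks.astype(float) ** 2
logits_one = ks.astype(float)

def count_distribution(n):
    logits = logits_bos + n * logits_one
    probs = np.exp(logits - logits.max())
    return probs / probs.sum()

for n in (3, 7, 15):
    print(n, int(count_distribution(n).argmax()))  # the peak sits at n
```
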
@@ -150,7 +150,7 @@ If you made it this far, here is an interesting bonus point:

* Even though the model has access to the number n through its attention profile, it still does not construct a probability distribution that is sharply peaked at n. As we see in the above figure, as n gets large, this probability distribution gets wider. This, we believe, is partly a side-effect of this specific solution, where two curves are being balanced against each other. But it is also partly a general problem: as the number of tokens that are attended to gets large, we need higher accuracy to be able to infer n exactly. This is because the information about n is coded non-linearly after the attention layer. In this case, if we assume that the model attends to BoS and 1-tokens equally, the output becomes:
<p align="center">
-<img src="/images/blog/counting/n_dependence.png" alt="The n-dependence of the model output." width="25%" style="mix-blend-mode: darken;">
+<img class="halfwidth" src="/images/blog/counting/n_dependence.png" alt="The n-dependence of the model output." width="25%" style="mix-blend-mode: darken;">
</p>
We see that as n becomes large, the difference between n and n+1 becomes smaller.
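
The formula in n_dependence.png is not reproduced in this diff; under the equal-attention assumption stated just above, the attention output should take roughly the following form (my reconstruction from the surrounding text, not copied from the image):

```latex
\mathrm{out}(n) = \frac{v_{\mathrm{BoS}} + n\,v_{1}}{n+1}
               = v_{1} + \frac{v_{\mathrm{BoS}} - v_{1}}{n+1},
\qquad
\mathrm{out}(n+1) - \mathrm{out}(n) = \frac{v_{1} - v_{\mathrm{BoS}}}{(n+1)(n+2)}.
```

The gap between consecutive values of n therefore shrinks roughly like 1/n², consistent with the observation that n and n+1 become harder to distinguish as n grows.
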
@@ -165,3 +165,14 @@ For more details, check out the [paper](https://arxiv.org/pdf/2406.02585).


Image by [Tim Mossholder](https://unsplash.com/photos/blue-and-black-electric-wires-FwzhysPCQZc) via Unsplash.
+
+<style>
+@media (max-width: 767px) {
+  .fullwidth {
+    width: 100% !important;
+  }
+  .halfwidth {
+    width: 60% !important;
+  }
+}
+</style>
