
Commit 36f4f13

added images for input,cp,prompt
1 parent f4d3d23 commit 36f4f13

5 files changed: 13 additions, 28 deletions


_layouts/post.html

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 </p>

 <h3 class="title">{{ page.title }}</h3>
-<h4 class="category">{{ page.date | date: "%b" }}, {{ page.date | date: "%d" }} {{ page.date | date: "%Y" }}</h4>
+<h4 class="category">{{ page.date | date: "%b" }} {{ page.date | date: "%d" }}, {{ page.date | date: "%Y" }}</h4>
 </div>
 </header>


_posts/2024-05-30-counting.md

Lines changed: 12 additions & 27 deletions
@@ -23,24 +23,15 @@ In this blog post, we summarize a recent paper which is part of an ongoing effor
 <br>
 In this task, the input is a sequence composed of zeros, ones, and square bracket delimiters: `{0, 1, [, ]}`. Each sample sequence contains ones and zeros with several regions marked by the delimiters. The task is to count the number of ones within each delimited region. For example, given the sequence:

-```
-input = [ 0 ] [ 1 0 1 ] 0 [ 1 ] 1 [ ] 0
-```
-
-The target output would be:
-
-```
-target = [0, 2, 1, 0]
-```
-
+<p align="center">
+<img src="/images/blog/counting/input.png" alt="Example of input." width="52%" style="mix-blend-mode: darken;">
+</p>
 For simplicity, the regions are non-overlapping.

 To tackle this task using a Transformer architecture, we use an encoder-decoder setup and provide the decoder with a fixed prompt that labels the regions. For the example above, the prompt would be:
-
-```
-prompt = [0, 1, 2, 3]
-```
-
+<p align="center">
+<img src="/images/blog/counting/prompt.png" alt="Example of prompt." width="52%" style="mix-blend-mode: darken;">
+</p>
 For our experiments, we fix the number of regions to 4 and the sequence length to 512. This allows us to explore how solutions generalize to different numbers of regions and sequence lengths.

 #### Relevance
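A minimal sketch of how one such training sample could be put together (hypothetical helper `make_sample`, simplified so that every 0/1 token sits inside a region; the real sequences also contain tokens outside the delimited regions):

```
import random

# Illustrative sketch only: build one contextual-counting sample over the
# vocabulary {0, 1, [, ]}, its per-region targets, and the fixed prompt.
def make_sample(n_regions=4, region_size=6, p_one=0.5):
    tokens, targets = [], []
    for _ in range(n_regions):
        tokens.append("[")
        ones = 0
        for _ in range(region_size):
            bit = 1 if random.random() < p_one else 0
            ones += bit
            tokens.append(str(bit))
        tokens.append("]")
        targets.append(ones)
    prompt = list(range(n_regions))  # fixed decoder prompt, e.g. [0, 1, 2, 3]
    return tokens, targets, prompt

tokens, targets, prompt = make_sample()
print("input  =", " ".join(tokens))
print("target =", targets)
print("prompt =", prompt)
```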
@@ -63,12 +54,9 @@ We provide some theoretical insights into the problem, showing that a Transforme
 #### Contextual Position (CP)

 Contextual position refers to positional information in a sequence that is meaningful only within the context of the problem. For the contextual counting task, this means knowing the region number for each token. For example, with three regions, the input and contextual positions might look like:
-
-```
-input = 0 1 1 [ 1 0 1 ] 0 [ 1 ] 1 [ ] 0
-CP    = - - - 1 1 1 1 1 - 2 2 2 - 3 3 -
-```
-
+<p align="center">
+<img src="/images/blog/counting/CP.png" alt="Example of contextual position." width="52%" style="mix-blend-mode: darken;">
+</p>
 This information helps disambiguate the different regions based on context.

 #### Key Propositions
@@ -103,7 +91,7 @@ We summarize the results of this empirical exploration below.
 #### 1. Causal transformers significantly outperform non-causal ones.

 <p align="center">
-<img src="/images/blog/counting/accuracy.png" alt="Performance of the different configurations" width="55%" style="mix-blend-mode: darken;">
+<img src="/images/blog/counting/accuracy.png" alt="Performance of the different configurations" width="65%" style="mix-blend-mode: darken;">
 </p>

 The above figure shows the performance of different Transformer configurations. The most prominent feature of this figure is that non-causal transformers with any positional encoding fail to achieve good performance. In contrast, causal Transformers can achieve close to 100\% accuracy.
@@ -120,7 +108,6 @@ As described above, the regional contextual position is an important piece of in
 <p align="center">
 <img src="/images/blog/counting/pca_proj.png" alt="PCA projection of the 1-tokens after the encoder layer." width="45%" style="mix-blend-mode: darken;">
 </p>
-
 By looking at the details of the attention module of the encoder, we see that in causal models, this information is inferred by attending to all the previous delimiter tokens equally. Each token can tell which region it is in by looking at how many delimiter tokens of each kind preceded it.


@@ -129,7 +116,7 @@ By looking at the details of the attention module of the encoder, we see that in
 We can verify explicitly that the inferred regional contextual position in the encoder is used in the decoder cross-attention module, such that the attention profile is focused on the 1-tokens of the relevant region (in the figure below, the third region).

 <p align="center">
-<img src="/images/blog/counting/decoder.png" alt="The attention profile of the decoder." width="45%" style="mix-blend-mode: darken;">
+<img src="/images/blog/counting/decoder.png" alt="The attention profile of the decoder." width="55%" style="mix-blend-mode: darken;">
 </p>

 We see that in this example the decoder also attends to the beginning-of-sequence (BoS) token. The reason is that if the model *only* attends to the 1-tokens, then the number of 1-tokens - the quantity of interest - cancels in the softmax normalization. If there is another token to attend to, however, the number of 1-tokens is preserved. In this way, this other token acts as a bias term when computing the output of the attention module.
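A toy numerical sketch of this bias-term effect (made-up two-dimensional value vectors, not the trained model's): averaging the values of the n 1-tokens alone yields the same output for every n, whereas adding a single BoS token keeps a weight of n/(n+1) on the 1-token value, so n survives the normalization.

```
import numpy as np

# Toy illustration with made-up value vectors.
v_one = np.array([1.0, 0.0])  # value vector shared by the 1-tokens
v_bos = np.array([0.0, 1.0])  # value vector of the BoS token

for n in (2, 10, 100):
    only_ones = np.mean([v_one] * n, axis=0)   # identical for every n
    with_bos = (n * v_one + v_bos) / (n + 1)   # varies with n via n/(n+1)
    print(n, only_ones, with_bos)
```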
@@ -147,7 +134,7 @@ We can get a hint at what might be the culprit by looking at the attention patte
 The figure below shows the attention pattern of the orange dots, i.e. the model that generalizes to different sequence lengths but not to different numbers of regions. We see that, as before, the decoder pays attention to the 1-tokens of the relevant region (in this case the first region); however, this time the role of the bias term is played by the ]-tokens. During training, the number of regions is fixed at 4, and therefore the number of ]-tokens can be used as a constant bias. This is no longer the case when the number of regions changes, which explains why this model does not generalize to other numbers of regions.

 <p align="center">
-<img src="/images/blog/counting/decoder_nongen.png" alt="The attention profile of the decoder of a non-generalizing model." width="45%" style="mix-blend-mode: darken;">
+<img src="/images/blog/counting/decoder_nongen.png" alt="The attention profile of the decoder of a non-generalizing model." width="55%" style="mix-blend-mode: darken;">
 </p>

 In our exploration, we found that the model can use any combination of quantities that are constant during training as biases.
@@ -175,11 +162,9 @@ Therefore, in these models we fully understand what attention patterns the model
 If you made it this far, here is an interesting bonus point:

 * Even though the model has access to the number n through its attention profile, it still does not construct a probability distribution that is sharply peaked at n. As we see in the above figure, as n gets large, this probability distribution gets wider. We believe this is partly a side effect of this specific solution, in which two curves are balanced against each other. But it is also partly a general problem: as the number of attended tokens grows, higher accuracy is needed to infer n exactly, because the information about n is encoded non-linearly after the attention layer. In this case, if we assume that the model attends equally to the BoS and 1-tokens, the output becomes:
-
 <p align="center">
 <img src="/images/blog/counting/n_dependence.png" alt="The n-dependence of the model output." width="25%" style="mix-blend-mode: darken;">
 </p>
-
 We see that as n becomes large, the difference between n and n+1 becomes smaller.

 <br>
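Spelling out this n-dependence in our own notation, under the stated assumption that the BoS token and each of the n 1-tokens receive the same attention weight 1/(n+1), the coefficient multiplying the shared 1-token value vector is

```
w(n) = \frac{n}{n+1},
\qquad
w(n+1) - w(n) = \frac{1}{(n+1)(n+2)} \longrightarrow 0 \quad \text{as } n \to \infty,
```

so distinguishing n from n+1 requires resolving an ever-shrinking gap in the attention output, consistent with the predicted distribution widening at large n.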

images/blog/counting/CP.png (26.6 KB)

images/blog/counting/input.png (29 KB)

images/blog/counting/prompt.png (18.1 KB)
