_posts/2024-05-30-counting.md

<br>
In this task, the input is a sequence composed of zeros, ones, and square-bracket delimiters: `{0, 1, [, ]}`. Each sample sequence contains ones and zeros, with several regions marked off by the delimiters. The task is to count the number of ones within each delimited region. For example, given the sequence:

```
input = [ 0 ] [ 1 0 1 ] 0 [ 1 ] 1 [ ] 0
```

The target output would be:

```
target = [0, 2, 1, 0]
```

<p align="center">
<img src="/images/blog/counting/input.png" alt="Example of input." width="52%" style="mix-blend-mode: darken;">
</p>

For simplicity, the regions are non-overlapping.

To tackle this task with a Transformer architecture, we use an encoder-decoder setup and provide the decoder with a fixed prompt that labels the regions. For the example above, the prompt would be:

```
prompt = [0, 1, 2, 3]
```

<p align="center">
<img src="/images/blog/counting/prompt.png" alt="Example of prompt." width="52%" style="mix-blend-mode: darken;">
</p>

For our experiments, we fix the number of regions to 4 and the sequence length to 512. This allows us to explore how solutions generalize to different numbers of regions and sequence lengths.
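
To make the setup concrete, here is a minimal sketch of how such samples could be generated. This is our own illustration rather than the code used in the paper; the `make_sample` helper and its details are assumptions based on the description above.

```python
import random

def make_sample(seq_len=512, n_regions=4, seed=None):
    """Generate one contextual-counting sample: a sequence over {0, 1, [, ]}
    with n_regions non-overlapping delimited regions, plus the per-region
    counts of 1-tokens that serve as the target."""
    rng = random.Random(seed)
    # Random background of zeros and ones, leaving room for the delimiters.
    tokens = [rng.choice(["0", "1"]) for _ in range(seq_len - 2 * n_regions)]
    # Insert non-overlapping [ ... ] pairs at sorted random positions.
    cuts = sorted(rng.sample(range(len(tokens) + 1), 2 * n_regions))
    for k, pos in enumerate(cuts):
        tokens.insert(pos + k, "[" if k % 2 == 0 else "]")
    # Count the ones inside each delimited region.
    target, inside, count = [], False, 0
    for tok in tokens:
        if tok == "[":
            inside, count = True, 0
        elif tok == "]":
            inside = False
            target.append(count)
        elif inside and tok == "1":
            count += 1
    prompt = list(range(n_regions))  # fixed decoder prompt labelling the regions
    return tokens, prompt, target

tokens, prompt, target = make_sample(seed=0)
print(len(tokens), prompt, target)  # sequence length, decoder prompt, per-region counts
```
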
#### Relevance
#### Contextual Position (CP)
Contextual position refers to positional information in a sequence that is meaningful only within the context of the problem. For the contextual counting task, this means knowing the region number for each token. For example, with three regions, the input and contextual positions might look like:

```
input = 0 1 1 [ 1 0 1 ] 0 [ 1 ] 1 [ ] 0
CP    = - - - 1 1 1 1 1 - 2 2 2 - 3 3 -
```

<p align="center">
<img src="/images/blog/counting/CP.png" alt="Example of contextual position." width="52%" style="mix-blend-mode: darken;">
</p>

This information helps disambiguate the different regions based on context.
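
For illustration, the contextual position can be read off with a single scan over the sequence. This is a sketch of ours, not code from the paper; the `-` label for tokens outside all regions simply follows the convention of the example above.

```python
def contextual_position(tokens):
    """Label every token with its region number (1-based); tokens outside
    any delimited region get '-'. Delimiters belong to the region they bound."""
    cp, region, inside = [], 0, False
    for tok in tokens:
        if tok == "[":
            region += 1
            inside = True
            cp.append(str(region))
        elif tok == "]":
            cp.append(str(region))
            inside = False
        else:
            cp.append(str(region) if inside else "-")
    return cp

seq = "0 1 1 [ 1 0 1 ] 0 [ 1 ] 1 [ ] 0".split()
print(" ".join(contextual_position(seq)))
# -> - - - 1 1 1 1 1 - 2 2 2 - 3 3 -
```
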
#### Key Propositions

We summarize the results of this empirical exploration below.

<p align="center">
<img src="/images/blog/counting/accuracy.png" alt="Performance of the different configurations." width="65%" style="mix-blend-mode: darken;">
</p>

The above figure shows the performance of different Transformer configurations. The most prominent feature of this figure is that non-causal Transformers fail to reach good performance with any positional encoding. In contrast, causal Transformers can achieve close to 100% accuracy.
<palign="center">
121
109
<imgsrc="/images/blog/counting/pca_proj.png"alt="PCA projection of the 1-tokens after the encoder layer."width="45%"style="mix-blend-mode: darken;">
122
110
</p>
123
-
124
111
By looking at the details of the attention module of the encoder, we see that in causal models this information is inferred by attending equally to all of the preceding delimiter tokens. Each token can tell which region it is in from how many delimiter tokens of each kind precede it.
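
As a toy illustration of this mechanism (our own calculation, not weights extracted from a trained model): if a token attends uniformly to the delimiter tokens that precede it, the attention output is the average of their value vectors, which encodes the fractions of `[` and `]` seen so far -- and those two counts determine the region.

```python
seq = "0 1 1 [ 1 0 1 ] 0 [ 1 ] 1 [ ] 0".split()

for i, tok in enumerate(seq):
    if tok != "1":
        continue
    # Under uniform attention over preceding delimiters, the output encodes
    # the fractions of '[' and ']' seen so far; here we just use the counts.
    opens = seq[:i].count("[")
    closes = seq[:i].count("]")
    label = f"region {opens}" if opens > closes else "outside any region"
    print(f"1-token at position {i:2d}: #[ = {opens}, #] = {closes} -> {label}")
```
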
We can verify explicitly that the regional contextual position inferred in the encoder is used by the decoder's cross-attention module, such that the attention profile focuses on the 1-tokens of the relevant region (the third region in the figure below).

<p align="center">
<img src="/images/blog/counting/decoder.png" alt="The attention profile of the decoder." width="55%" style="mix-blend-mode: darken;">
</p>

We see that in this example the decoder also attends to the beginning-of-sequence (BoS) token. The reason is that if the model attended *only* to the 1-tokens, the number of 1-tokens -- the quantity of interest -- would cancel in the softmax normalization. If another token is also attended to, the number of 1-tokens is preserved. In this way, the extra token acts as a bias term when computing the output of the attention module.
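
Here is a small numerical sketch of this cancellation, using made-up value vectors rather than the trained model's weights: with uniform attention over the 1-tokens alone, the output is independent of their number n, whereas adding a single BoS-like token makes the output carry n through a factor n/(n+1).

```python
import numpy as np

v_one = np.array([1.0, 0.0])  # value vector shared by all 1-tokens (made up)
v_bos = np.array([0.0, 1.0])  # value vector of the BoS token (made up)

for n in [2, 5, 10]:
    # Equal scores over n identical 1-tokens: softmax weights are 1/n each,
    # so the output is just v_one and n cancels out.
    out_without_bias = n * (1.0 / n) * v_one
    # Equal scores over the n 1-tokens plus BoS: weights are 1/(n+1) each,
    # so the 1-token contribution is n/(n+1) and n survives in the output.
    out_with_bias = (n / (n + 1)) * v_one + (1 / (n + 1)) * v_bos
    print(n, out_without_bias, out_with_bias)
```
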
The figure below shows the attention pattern of the orange dots, i.e. the model that generalizes to different sequence lengths but not to different numbers of regions. As before, the decoder attends to the 1-tokens of the relevant region (in this case the first region), but this time the role of the bias term is played by the ]-tokens. During training the number of regions is fixed at 4, so the number of ]-tokens can be used as a constant bias. This is no longer the case when the number of regions changes, which explains why this model does not generalize to other numbers of regions.
<palign="center">
150
-
<imgsrc="/images/blog/counting/decoder_nongen.png"alt="The attention profile of the decoder of a non-generalizing model."width="45%"style="mix-blend-mode: darken;">
137
+
<imgsrc="/images/blog/counting/decoder_nongen.png"alt="The attention profile of the decoder of a non-generalizing model."width="55%"style="mix-blend-mode: darken;">
151
138
</p>
152
139
153
140
In our exploration, we found that the model can use any combination of quantities that are constant during training as biases.
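
As a hedged illustration of why such a bias only works while the corresponding quantity stays constant (again a toy calculation of ours, not the model's exact computation): if the ]-tokens act as the bias, the attention weight landing on the 1-tokens behaves like n/(n + R), where R is the number of ]-tokens; a readout calibrated for R = 4 during training gives the wrong count once R changes.

```python
def one_token_weight(n_ones, n_bias_tokens):
    # Uniform attention over the n 1-tokens plus the bias tokens:
    # total softmax weight landing on the 1-tokens.
    return n_ones / (n_ones + n_bias_tokens)

n = 7
print(one_token_weight(n, 1))  # BoS as bias: depends only on n
print(one_token_weight(n, 4))  # four ]-tokens as bias: what training calibrated for
print(one_token_weight(n, 6))  # six regions at test time: same n, different weight
```
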
If you made it this far, here is an interesting bonus point:

* Even though the model has access to the number n through its attention profile, it still does not construct a probability distribution that is sharply peaked at n. As we see in the above figure, as n gets large this probability distribution gets wider. We believe this is partly a side effect of this specific solution, in which two curves are balanced against each other. But it is also partly a general problem: as the number of attended tokens grows, higher accuracy is needed to infer n exactly, because the information about n is encoded non-linearly after the attention layer. In this case, if we assume that the model attends to the BoS token and the 1-tokens equally, the output becomes:
<p align="center">
<imgsrc="/images/blog/counting/n_dependence.png"alt="The n-dependence of the model output."width="25%"style="mix-blend-mode: darken;">
</p>

We see that as n becomes large, the difference between the outputs for n and n+1 becomes smaller.
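
To make this concrete with the uniform-attention sketch from above (an assumed functional form, not necessarily the exact formula the model implements): if the weight on the 1-tokens scales like n/(n+1), the gap between consecutive values of n shrinks like 1/((n+1)(n+2)), so ever higher precision is needed to distinguish n from n+1.

```python
for n in [1, 10, 100, 1000]:
    f_n = n / (n + 1)
    f_next = (n + 1) / (n + 2)
    print(n, f_next - f_n)  # equals 1/((n+1)(n+2)): vanishes as n grows
```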