-
Probably a trivial question, but I can't find the right keywords to find an answer. On one hand, JAX's JIT compilation requires static array shapes. On the other hand, attention in transformers operates on sequences of different lengths. Of course, it's possible to pad all sequences to the maximum length, but if, for example, my maximum is 4096 tokens and the real sequence is 30 tokens, then that's a huge waste of computational resources. How do people solve this dilemma?
-
Incremental padding to sizes that are either powers of 2, or multiples of some other fixed size.
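A minimal sketch of what that bucketing can look like, assuming a power-of-2 bucketing rule and a boolean mask for the padded positions; the helper names (`pad_to_bucket`, `masked_sum`) are hypothetical placeholders, not from any specific library:

```python
import jax
import jax.numpy as jnp
import numpy as np


def next_power_of_two(n: int) -> int:
    # Round a length up to the next power of two (30 -> 32, 33 -> 64, ...).
    return 1 << (int(n) - 1).bit_length()


def pad_to_bucket(tokens: np.ndarray, bucket_fn=next_power_of_two):
    """Pad a 1-D token array to its bucket length; return tokens and a mask."""
    length = tokens.shape[0]
    padded_length = bucket_fn(length)
    padded = np.zeros(padded_length, dtype=tokens.dtype)
    padded[:length] = tokens
    mask = np.zeros(padded_length, dtype=bool)
    mask[:length] = True
    return jnp.asarray(padded), jnp.asarray(mask)


@jax.jit
def masked_sum(tokens, mask):
    # Stand-in for a real model call: the mask zeroes out padded positions.
    return jnp.sum(jnp.where(mask, tokens, 0))


# A 30-token sequence is padded to 32, not 4096; any other sequence that
# falls into the same bucket reuses the compiled function without retracing.
tokens, mask = pad_to_bucket(np.arange(30, dtype=np.int32))
print(tokens.shape, masked_sum(tokens, mask))
```

The JIT-compiled function then only ever sees a small number of distinct shapes (one per bucket), so recompilation happens rarely while padding waste stays bounded.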
In general, the strategy for handling dynamic array sizes in JAX is, depending on the situation, to either dispatch a new JIT-compiled operation for each size, or to pad all entries to a maximum size. In practice, the extra calculations done in the padding strategy may not be as important as you fear, especially if you're running on an accelerator like GPU or TPU.
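A minimal sketch contrasting the two strategies, assuming a toy scoring function in place of a real transformer (which would additionally feed the mask into its attention layers):

```python
import jax
import jax.numpy as jnp


@jax.jit
def score(tokens):
    # Strategy 1: called with many different lengths, this traces and
    # compiles once per distinct input shape.
    return jnp.mean(tokens.astype(jnp.float32))


@jax.jit
def score_padded(tokens, mask):
    # Strategy 2: every call has the same (maximum) length, so there is a
    # single compilation; the mask keeps padded positions out of the result.
    x = tokens.astype(jnp.float32)
    return jnp.sum(jnp.where(mask, x, 0.0)) / jnp.sum(mask)


max_len = 4096
real = jnp.arange(30)
padded = jnp.zeros(max_len, dtype=real.dtype).at[:30].set(real)
mask = jnp.zeros(max_len, dtype=bool).at[:30].set(True)

print(score(real))                 # compiles for length 30
print(score_padded(padded, mask))  # compiles once for length 4096, same answer
```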