Hi authors, in models/cache.py (lines 154–159), the code computes the mean of the key vectors in each chunk and then selects the top-k chunks by the dot product between the query and that mean key. However, this is not equivalent to selecting the top-k chunks by their averaged attention scores, i.e. the mean of the softmax-normalized query-key dot products within each chunk. Because softmax is nonlinear, a chunk containing a single strongly matching key can get a low mean-key score while still having a high averaged attention score.
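To make the mismatch concrete, here is a small standalone sketch (not the code from models/cache.py). The helper names `mean_key_scores`, `mean_softmax_scores`, and `top1` are hypothetical, the toy keys are made up for illustration, and I am assuming the softmax is taken over all keys across chunks with no `1/sqrt(d)` scaling:

```python
import math


def mean_key_scores(query, chunks):
    """Score each chunk by the dot product of the query with the chunk's mean key."""
    scores = []
    for keys in chunks:
        mean_key = [sum(k[d] for k in keys) / len(keys) for d in range(len(query))]
        scores.append(sum(q * m for q, m in zip(query, mean_key)))
    return scores


def mean_softmax_scores(query, chunks):
    """Score each chunk by the mean softmax attention weight of its keys.

    Assumption: the softmax runs over all keys from all chunks together.
    """
    logits = [[sum(q * x for q, x in zip(query, k)) for k in keys] for keys in chunks]
    max_logit = max(l for chunk in logits for l in chunk)  # for numerical stability
    exps = [[math.exp(l - max_logit) for l in chunk] for chunk in logits]
    total = sum(e for chunk in exps for e in chunk)
    return [sum(chunk) / total / len(chunk) for chunk in exps]


def top1(scores):
    """Index of the highest-scoring chunk (top-k with k = 1)."""
    return max(range(len(scores)), key=scores.__getitem__)


query = [1.0, 0.0]
chunks = [
    [[3.0, 0.0], [-3.0, 0.0]],  # chunk 0: one strong match, but the mean key is zero
    [[0.5, 0.0], [0.5, 0.0]],   # chunk 1: two weak matches
]

by_mean_key = mean_key_scores(query, chunks)
by_softmax = mean_softmax_scores(query, chunks)
print(top1(by_mean_key), top1(by_softmax))  # prints: 1 0
```

In this toy case the mean-key criterion picks chunk 1, while the averaged-softmax criterion picks chunk 0, so the two selections disagree even with k = 1.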
Anyway, I'm not sure whether this leads to accumulated errors in the retrieval cache during generation, but I hope this observation helps!