 #include <stdexcept>
 #include <cinttypes>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-    return relative_bucket;
-}
-
 //
 // llama_context
 //
@@ -346,6 +322,7 @@ class llama_io_write_dummy : public llama_io_write_i {
         return size_written;
     }
 
+private:
     size_t size_written = 0;
 };
 
@@ -378,6 +355,7 @@ class llama_io_write_buffer : public llama_io_write_i {
         return size_written;
     }
 
+private:
     uint8_t * ptr;
     size_t buf_size = 0;
     size_t size_written = 0;
@@ -406,6 +384,7 @@ class llama_io_read_buffer : public llama_io_read_i {
         return size_read;
     }
 
+private:
     const uint8_t * ptr;
     size_t buf_size = 0;
     size_t size_read = 0;
@@ -430,6 +409,7 @@ class llama_io_write_file : public llama_io_write_i {
         return size_written;
     }
 
+private:
     llama_file * file;
     size_t size_written = 0;
     std::vector<uint8_t> temp_buffer;
@@ -454,6 +434,7 @@ class llama_io_read_file : public llama_io_read_i {
         return size_read;
     }
 
+private:
     llama_file * file;
     size_t size_read = 0;
     std::vector<uint8_t> temp_buffer;
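
All five state-I/O helpers get the same treatment: the bookkeeping members (pointers, counters, scratch buffers) move behind `private:`, so callers can only reach them through the `llama_io_write_i` / `llama_io_read_i` virtual interface. Below is a minimal sketch of the idiom, using hypothetical names and a simplified interface rather than the actual llama.cpp declarations:

#include <cstddef>
#include <cstdint>
#include <vector>

// simplified stand-in for an abstract writer interface like llama_io_write_i
struct io_write_i {
    virtual ~io_write_i() = default;
    virtual void write(const void * src, size_t size) = 0;
    virtual size_t n_bytes() const = 0;
};

class io_write_buffer : public io_write_i {
public:
    explicit io_write_buffer(std::vector<uint8_t> & buf) : buf(buf) {}

    void write(const void * src, size_t size) override {
        const uint8_t * p = static_cast<const uint8_t *>(src);
        buf.insert(buf.end(), p, p + size);
        size_written += size;
    }

    size_t n_bytes() const override { return size_written; }

private:
    std::vector<uint8_t> & buf;  // reachable only through write()
    size_t size_written = 0;     // exposed read-only through n_bytes()
};
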
@@ -2132,22 +2113,46 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
     GGML_ASSERT(ggml_backend_buffer_is_host(inp_pos_bucket->buffer));
     GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
 
+    static const auto relative_position_bucket = [](llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+        // TODO move to hparams if a T5 variant appears that uses a different value
+        const int64_t max_distance = 128;
+
+        if (bidirectional) {
+            n_buckets >>= 1;
+        }
+
+        const int64_t max_exact = n_buckets >> 1;
+
+        int32_t relative_position = x - y;
+        int32_t relative_bucket = 0;
+        if (bidirectional) {
+            relative_bucket += (relative_position > 0) * n_buckets;
+            relative_position = abs(relative_position);
+        } else {
+            relative_position = -std::min<int32_t>(relative_position, 0);
+        }
+        int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+        relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+        relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+        return relative_bucket;
+    };
+
     int32_t * data = (int32_t *) inp_pos_bucket->data;
 
     if (!is_encoding) {
         const int64_t n_kv = kv_self.n;
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                 }
             }
         }
     } else {
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 for (int i = 0; i < n_tokens; ++i) {
-                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
+                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                 }
             }
         }
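
The lambda reproduces the T5-style relative position bucketing: offsets smaller than max_exact each get their own bucket, larger offsets share logarithmically spaced buckets up to max_distance, and bidirectional (encoder) attention splits the bucket range between positive and negative offsets. Here is a standalone sketch of the same math, restructured with an early return so log(0) is never evaluated, using illustrative values rather than anything taken from a model:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

static int32_t bucket(int32_t x, int32_t y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1; // split the range between positive and negative offsets
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;
    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }

    // small offsets: one bucket per offset
    if (relative_position < max_exact) {
        return relative_bucket + relative_position;
    }

    // large offsets: log-spaced buckets, saturating at max_distance
    int32_t if_large = floorf(max_exact + logf(1.0f * relative_position / max_exact) * (n_buckets - max_exact) / logf(1.0f * max_distance / max_exact));
    return relative_bucket + std::min<int32_t>(if_large, n_buckets - 1);
}

int main() {
    // 32 buckets, bidirectional: offsets 0..7 map one-to-one,
    // larger offsets share log-spaced buckets
    const int offsets[] = {0, 1, 7, 8, 16, 64, 200};
    for (int d : offsets) {
        printf("offset %4d -> bucket %2d\n", d, bucket(0, d, 32, true));
    }
    return 0;
}

With 32 buckets in bidirectional mode, offsets 0 through 7 map one-to-one, offset 16 lands in bucket 10, and everything at or beyond max_distance saturates at bucket 15.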