 #include <stdexcept>
 #include <cinttypes>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-    return relative_bucket;
-}
-
 //
 // llama_context
 //
@@ -346,6 +322,7 @@ class llama_io_write_dummy : public llama_io_write_i {
         return size_written;
     }
 
+private:
     size_t size_written = 0;
 };
 
@@ -378,6 +355,7 @@ class llama_io_write_buffer : public llama_io_write_i {
         return size_written;
     }
 
+private:
     uint8_t * ptr;
     size_t buf_size = 0;
     size_t size_written = 0;
@@ -406,6 +384,7 @@ class llama_io_read_buffer : public llama_io_read_i {
         return size_read;
     }
 
+private:
     const uint8_t * ptr;
     size_t buf_size = 0;
     size_t size_read = 0;
@@ -430,6 +409,7 @@ class llama_io_write_file : public llama_io_write_i {
         return size_written;
     }
 
+private:
     llama_file * file;
     size_t size_written = 0;
     std::vector<uint8_t> temp_buffer;
@@ -454,6 +434,7 @@ class llama_io_read_file : public llama_io_read_i {
         return size_read;
     }
 
+private:
     llama_file * file;
     size_t size_read = 0;
     std::vector<uint8_t> temp_buffer;
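The five hunks above all apply the same encapsulation pattern: the bookkeeping fields (`ptr`, `buf_size`, `size_written`, `size_read`, `file`, `temp_buffer`) move behind a `private:` specifier, so callers can only observe them through the public interface. A minimal sketch of the idea follows; the `write()`/`n_bytes()` interface is a guess, since the full class bodies are not visible in these hunks:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Sketch of the encapsulation applied above (hypothetical interface):
// the counters are mutated only inside the class and read via an accessor.
class io_write_buffer_sketch {
public:
    io_write_buffer_sketch(uint8_t * p, size_t n) : ptr(p), buf_size(n) {}

    void write(const void * src, size_t n) {
        if (n > buf_size) { n = buf_size; } // clamp; the real class may throw instead
        std::memcpy(ptr, src, n);
        ptr          += n;
        buf_size     -= n;
        size_written += n;
    }

    // read-only view of the counter, mirroring `return size_written;` above
    size_t n_bytes() const { return size_written; }

private:
    uint8_t * ptr          = nullptr;
    size_t    buf_size     = 0;
    size_t    size_written = 0;
};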
@@ -2132,22 +2113,46 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(inp_pos_bucket->buffer));
         GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
 
+        static const auto relative_position_bucket = [](llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+            // TODO move to hparams if a T5 variant appears that uses a different value
+            const int64_t max_distance = 128;
+
+            if (bidirectional) {
+                n_buckets >>= 1;
+            }
+
+            const int64_t max_exact = n_buckets >> 1;
+
+            int32_t relative_position = x - y;
+            int32_t relative_bucket = 0;
+            if (bidirectional) {
+                relative_bucket += (relative_position > 0) * n_buckets;
+                relative_position = abs(relative_position);
+            } else {
+                relative_position = -std::min<int32_t>(relative_position, 0);
+            }
+            int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+            relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+            relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+            return relative_bucket;
+        };
+
         int32_t * data = (int32_t *) inp_pos_bucket->data;
 
         if (!is_encoding) {
             const int64_t n_kv = kv_self.n;
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                     }
                 }
             }
         } else {
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
+                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                     }
                 }
             }
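The lambda above is the T5-style relative position bucketing: distances below `max_exact` each get their own bucket, while larger distances share logarithmically spaced buckets up to `max_distance`. A standalone sketch of the same logic, runnable outside llama.cpp; `llama_pos` is replaced by `int32_t`, and `n_buckets = 32` is an illustrative value (in the diff it comes from `hparams.n_rel_attn_bkts`):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Same logic as the relative_position_bucket lambda in the diff above.
static int32_t relative_position_bucket(int32_t x, int32_t y, uint64_t n_buckets, bool bidirectional) {
    const int64_t max_distance = 128;

    if (bidirectional) {
        n_buckets >>= 1;
    }

    const int64_t max_exact = n_buckets >> 1;

    int32_t relative_position = x - y;
    int32_t relative_bucket = 0;
    if (bidirectional) {
        relative_bucket += (relative_position > 0) * n_buckets;
        relative_position = abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0);
    }
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
    return relative_bucket;
}

int main() {
    // n_buckets = 32 is an assumption for illustration, not taken from the diff.
    // Small distances map one-to-one; past max_exact the buckets grow logarithmically.
    for (int32_t d : {1, 2, 8, 15, 16, 32, 64, 127}) {
        printf("distance %3d -> bucket %2d\n", d, relative_position_bucket(0, d, 32, false));
    }
    return 0;
}

Running it shows distances 1..15 landing in buckets 1..15 exactly, while 16, 32, 64, and 127 compress into buckets 16, 21, 26, and 31 respectively, which is why the function needs no per-position table.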