Commit e08f38d

context : minor cleanup
ggml-ci
1 parent f7c7757 commit e08f38d

File tree: 1 file changed, +31 -26 lines

src/llama-context.cpp

Lines changed: 31 additions & 26 deletions
@@ -10,30 +10,6 @@
 #include <stdexcept>
 #include <cinttypes>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-    return relative_bucket;
-}
-
 //
 // llama_context
 //
@@ -346,6 +322,7 @@ class llama_io_write_dummy : public llama_io_write_i {
         return size_written;
     }
 
+private:
     size_t size_written = 0;
 };
 
@@ -378,6 +355,7 @@ class llama_io_write_buffer : public llama_io_write_i {
         return size_written;
     }
 
+private:
     uint8_t * ptr;
     size_t buf_size = 0;
     size_t size_written = 0;
@@ -406,6 +384,7 @@ class llama_io_read_buffer : public llama_io_read_i {
         return size_read;
     }
 
+private:
     const uint8_t * ptr;
     size_t buf_size = 0;
     size_t size_read = 0;
@@ -430,6 +409,7 @@ class llama_io_write_file : public llama_io_write_i {
         return size_written;
     }
 
+private:
     llama_file * file;
     size_t size_written = 0;
     std::vector<uint8_t> temp_buffer;
@@ -454,6 +434,7 @@ class llama_io_read_file : public llama_io_read_i {
         return size_read;
     }
 
+private:
     llama_file * file;
     size_t size_read = 0;
     std::vector<uint8_t> temp_buffer;
@@ -2132,22 +2113,46 @@ void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
     GGML_ASSERT(ggml_backend_buffer_is_host(inp_pos_bucket->buffer));
     GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
 
+    static const auto relative_position_bucket = [](llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+        // TODO move to hparams if a T5 variant appears that uses a different value
+        const int64_t max_distance = 128;
+
+        if (bidirectional) {
+            n_buckets >>= 1;
+        }
+
+        const int64_t max_exact = n_buckets >> 1;
+
+        int32_t relative_position = x - y;
+        int32_t relative_bucket = 0;
+        if (bidirectional) {
+            relative_bucket += (relative_position > 0) * n_buckets;
+            relative_position = abs(relative_position);
+        } else {
+            relative_position = -std::min<int32_t>(relative_position, 0);
+        }
+        int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+        relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+        relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+        return relative_bucket;
+    };
+
     int32_t * data = (int32_t *) inp_pos_bucket->data;
 
     if (!is_encoding) {
         const int64_t n_kv = kv_self.n;
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = relative_position_bucket(kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                 }
             }
         }
     } else {
         for (int h = 0; h < 1; ++h) {
             for (int j = 0; j < n_tokens; ++j) {
                 for (int i = 0; i < n_tokens; ++i) {
-                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
+                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, is_encoding);
                 }
             }
         }
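
For readers unfamiliar with the helper being relocated, below is a minimal standalone sketch (not part of the commit) of the T5-style relative-position bucketing that the moved lambda computes. The function body mirrors the diff above with llama_pos replaced by int32_t; only max_distance = 128 is fixed in the source, while n_buckets = 32 and the sample distances are illustrative assumptions (in the real code the bucket count comes from hparams.n_rel_attn_bkts).

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>

    // Same arithmetic as the relative_position_bucket lambda above, made
    // self-contained for demonstration purposes.
    static int32_t relative_position_bucket(int32_t x, int32_t y, uint64_t n_buckets, bool bidirectional) {
        const int64_t max_distance = 128;

        if (bidirectional) {
            n_buckets >>= 1; // half of the buckets for each sign of the offset
        }

        const int64_t max_exact = n_buckets >> 1;

        int32_t relative_position = x - y;
        int32_t relative_bucket = 0;
        if (bidirectional) {
            relative_bucket += (relative_position > 0) * n_buckets; // upper half for positive offsets
            relative_position = std::abs(relative_position);
        } else {
            relative_position = -std::min<int32_t>(relative_position, 0); // causal: only look backwards
        }
        // offsets below max_exact keep their own bucket; larger ones are log-spaced up to max_distance
        int32_t relative_position_if_large = (int32_t) std::floor(
                max_exact + std::log(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / std::log(1.0 * max_distance / max_exact));
        relative_position_if_large = std::min<int32_t>(relative_position_if_large, (int32_t) n_buckets - 1);
        relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
        return relative_bucket;
    }

    int main() {
        // key at position 0, query at position d, 32 buckets (illustrative value)
        for (int d : {1, 4, 8, 16, 64, 127, 300}) {
            std::printf("distance %4d -> decoder bucket %2d, encoder bucket %2d\n",
                    d,
                    relative_position_bucket(0, d, 32, /*bidirectional=*/false),
                    relative_position_bucket(0, d, 32, /*bidirectional=*/true));
        }
        return 0;
    }

In short: small offsets each map to their own bucket, larger offsets share logarithmically spaced buckets up to max_distance, and in the bidirectional (encoder) case the sign of the offset selects the lower or upper half of the bucket range.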
