Commit 102dd30
Message: Latest commits, added candidates strings display
Parent: 429d69f
40 files changed: +2365, -531 lines

base_sampling2/chat_layer.h

Lines changed: 29 additions & 11 deletions

@@ -73,6 +73,8 @@ extern int num_probs_bottoms;
 
 extern float confidence_total;
 
+extern std::vector<llama_token> last_candidates_logits;
+
 #define SESSIONS_FOLDER "sessions/"
 
 static common_params paramsDefault;

@@ -190,6 +192,7 @@ class chat
 private:
 
     llama_context * ctx = nullptr;
+    llama_memory_t mem = nullptr;
    llama_model * model = nullptr;
    common_sampler * smpl = nullptr;
    const llama_vocab * vocab = nullptr;

@@ -289,6 +292,8 @@ class chat
    std::string logit_bias_strings_ext_display = "";
    std::string logit_bias_strings_start_display = "";
 
+    std::string last_candidates_logits_display = "";
+
    struct llama_perf_context_data ctx_performance_data;
 
    //std::map<std::string,std::string> stats;

@@ -765,6 +770,14 @@ class chat
        }
    }
 
+    void get_last_candidates_logits_display() {
+        last_candidates_logits_display.clear();
+
+        for (auto logit : last_candidates_logits) {
+            last_candidates_logits_display += std::format("{}; ", common_token_to_piece(ctx, logit));
+        }
+    }
+
    void params_postfill() {
        if (params.kv_overrides_pair.size()) kv_override_prefill();
        common_process_override_tensors(params);

@@ -1250,6 +1263,9 @@ class chat
        ctx = llama_init.context.release();
        printf("..............CONTEXT INITIALIZED (%s)................\n", __func__);
 
+        mem = llama_get_memory(ctx);
+        printf("..............MEM INITIALIZED (%s)................\n", __func__);
+
        assignThreads();
        printf("..............THREADS ASSIGNED (%s)................\n", __func__);
 

@@ -1402,7 +1418,7 @@ class chat
 
        // remove any "future" tokens that we might have inherited from the previous session
        //llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1);
-        llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1);
+        llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1);
    }
 
    // if we will use the cache for the full prompt without reaching the end of the cache, force

@@ -1475,8 +1491,8 @@ class chat
    // always keep the first token - BOS
    //n_past = std::max(1, params.n_keep);
    //n_past_guidance = std::max(1, params.n_keep + guidance_offset);
-    llama_kv_self_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
-    llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+    llama_memory_seq_rm (mem, 0, params.n_keep , params.n_keep + n_discard);
+    llama_memory_seq_add(mem, 0, params.n_keep + n_discard, n_past, -n_discard);
 
    // insert n_left/2 tokens at the start of embd from last_n_tokens
    //embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());

@@ -1510,8 +1526,8 @@ class chat
    const int n_left = n_past - params.n_keep;
    const int n_discard = n_left/2;
 
-    llama_kv_self_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
-    llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+    llama_memory_seq_rm (mem, 0, params.n_keep , params.n_keep + n_discard);
+    llama_memory_seq_add(mem, 0, params.n_keep + n_discard, n_past, -n_discard);
 
    n_past -= n_discard;
 

@@ -1524,9 +1540,9 @@ class chat
    const int bd = (ga_w/ga_n)*(ga_n - 1);
    const int dd = (ga_w/ga_n) - ib*bd - ga_w;
 
-    llama_kv_self_seq_add(ctx, 0, ga_i, n_past, ib*bd);
-    llama_kv_self_seq_div (ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
-    llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
+    llama_memory_seq_add(mem, 0, ga_i, n_past, ib*bd);
+    llama_memory_seq_div(mem, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
+    llama_memory_seq_add(mem, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
 
    n_past -= bd;
 

@@ -1650,6 +1666,8 @@ class chat
    // const llama_token id = common_sampler_sample(smpl, ctx, -1);
    llama_token id = common_sampler_sample(smpl, ctx, -1);
 
+    get_last_candidates_logits_display();
+
    // try to sample a different token to avoid empty messages
    int attempts = 1000; // safeguard
    while (emptyMessage == true && llama_token_is_eog(vocab, id) && attempts > 0) {

@@ -1738,7 +1756,7 @@ class chat
    capture_smpl();
    // rewind_state.capture_kv_cache(llama_kv_cache_seq_pos_max(ctx, 0));
    // rewind_state.capture_kv_cache(llama_kv_self_seq_pos_max(ctx, -1));
-    rewind_state.capture_kv_cache(llama_kv_self_seq_pos_max(ctx, 0));
+    rewind_state.capture_kv_cache(llama_memory_seq_pos_max(mem, 0));
    rewind_state.capture_embd_inp(embd_inp.size());
    rewind_state.capture_n_past(n_past);
    rewind_state.capture_n_consumed(n_consumed);

@@ -1748,7 +1766,7 @@ class chat
    int get_kv_cache_seq_pos_max() {
        // return llama_kv_cache_seq_pos_max(ctx, 0);
        // return llama_kv_self_seq_pos_max(ctx, -1);
-        return llama_kv_self_seq_pos_max(ctx, 0);
+        return llama_memory_seq_pos_max(mem, 0);
    }
 
    void clearStates2() {

@@ -1764,7 +1782,7 @@ class chat
    restore_smpl();
    //common_sampler_reset(smpl);
    // context
-    llama_kv_self_seq_rm(ctx, 0, rewind_state.kv_cache_pos, -1);
+    llama_memory_seq_rm(mem, 0, rewind_state.kv_cache_pos, -1);
    // llama_kv_self_seq_rm(ctx, -1, rewind_state.kv_cache_pos, -1);
    // llama_kv_cache_update(ctx);
    // chat parameters
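
The common thread in these hunks is the migration from the llama_kv_self_* calls, which operated directly on the llama_context, to the newer llama_memory_* API, which works through an opaque llama_memory_t handle obtained once via llama_get_memory(ctx). A minimal sketch of the context-shift step above, assuming current llama.h signatures (the helper name shift_context is ours, not part of this commit):

// Sketch only: drop the oldest tokens after n_keep, then slide the rest of
// the KV cache back so generation can continue in the freed space.
#include "llama.h"

static void shift_context(llama_context * ctx, int n_keep, int & n_past) {
    llama_memory_t mem = llama_get_memory(ctx);   // opaque KV-cache handle

    const int n_left    = n_past - n_keep;
    const int n_discard = n_left / 2;

    // remove positions [n_keep, n_keep + n_discard) from sequence 0 ...
    llama_memory_seq_rm (mem, 0, n_keep, n_keep + n_discard);
    // ... and shift positions [n_keep + n_discard, n_past) down by n_discard
    llama_memory_seq_add(mem, 0, n_keep + n_discard, n_past, -n_discard);

    n_past -= n_discard;
}

The llama_memory_seq_div call in the grouped-attention hunk follows the same pattern, except that it compresses a position range by an integer factor instead of shifting it.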

base_sampling2/common.cpp

Lines changed: 2 additions & 2 deletions

@@ -875,7 +875,7 @@ struct common_init_result common_init_from_params(common_params & params) {
        return iparams;
    }
 
-    if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
+    if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
        printf("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__);
        llama_model_free(model);
        return iparams;

@@ -982,7 +982,7 @@
    if (llama_model_has_decoder(model)) {
        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
    }
-    llama_kv_self_clear(lctx);
+    llama_memory_clear(llama_get_memory(lctx), true);
    llama_synchronize(lctx);
    llama_perf_context_reset(lctx);
    llama_set_warmup(lctx, false);
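
One detail worth flagging: llama_memory_clear takes a second argument that llama_kv_self_clear did not. In current llama.h it selects whether the cache's data buffers are cleared along with the metadata, so passing true here reproduces the old full-clear behavior after the warmup decode.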

base_sampling2/include/jsonParams.h

Lines changed: 34 additions & 6 deletions

@@ -46,6 +46,32 @@ static std::string extract_string_mod(std::string& text, std::string open, std::
    return "NULL";
 }
 
+static void extract_logit_bias_strings(std::string& text, std::string open, std::string close, nlohmann::json& config) {
+    size_t open_pos = text.rfind(open);
+    size_t close_pos = text.rfind(close);
+    if (open_pos != text.npos && close_pos != text.npos && close_pos >= open_pos + open.length()) {
+        size_t diff = close_pos - open_pos - open.length();
+        std::string extract = text.substr(open_pos + open.length(), diff);
+        std::cout << "Extracting: " << extract << std::endl;
+        text.replace(open_pos, diff + open.length() + close.length(), "");
+
+        if (extract.size() > 1) {
+            size_t pos = 0;
+            for (size_t i = 0; i < extract.size(); i++) {
+                if (extract[i] == ',') {
+                    std::string part = extract.substr(pos, i - pos); // substr takes a length, not an end index
+                    config["logit_bias_strings_exact"].push_back(part);
+                    std::cout << "Adding " << part << std::endl;
+                    pos = i + 1;
+                }
+            }
+        } else if (extract.size() > 0) {
+            config["logit_bias_strings_exact"].push_back(extract);
+            std::cout << "Adding " << extract << std::endl;
+        }
+    }
+}
+
 static bool replace_string_mod(std::string& text, std::string target, std::string replacement) {
    size_t target_pos = text.rfind(target);

@@ -621,9 +647,10 @@ static void getPerformanceParamsFromJson(nlohmann::json& config, common_params&
    if (checkJNum(config, "n_threads_sched_priority")) {
        int sched_priority = config["n_threads_sched_priority"];
        switch (sched_priority){
-            case 1: params.cpuparams.priority = GGML_SCHED_PRIO_MEDIUM; break;
-            case 2: params.cpuparams.priority = GGML_SCHED_PRIO_HIGH; break;
-            case 3: params.cpuparams.priority = GGML_SCHED_PRIO_REALTIME; break;
+            case 1: params.cpuparams.priority = GGML_SCHED_PRIO_LOW; break;
+            case 2: params.cpuparams.priority = GGML_SCHED_PRIO_MEDIUM; break;
+            case 3: params.cpuparams.priority = GGML_SCHED_PRIO_HIGH; break;
+            case 4: params.cpuparams.priority = GGML_SCHED_PRIO_REALTIME; break;
            default: params.cpuparams.priority = GGML_SCHED_PRIO_NORMAL; break;
        }
    }

@@ -634,9 +661,10 @@ static void getPerformanceParamsFromJson(nlohmann::json& config, common_params&
    if (checkJNum(config, "n_threads_batch_sched_priority")) {
        int sched_priority = config["n_threads_batch_sched_priority"];
        switch (sched_priority){
-            case 1: params.cpuparams_batch.priority = GGML_SCHED_PRIO_MEDIUM; break;
-            case 2: params.cpuparams_batch.priority = GGML_SCHED_PRIO_HIGH; break;
-            case 3: params.cpuparams_batch.priority = GGML_SCHED_PRIO_REALTIME; break;
+            case 1: params.cpuparams_batch.priority = GGML_SCHED_PRIO_LOW; break;
+            case 2: params.cpuparams_batch.priority = GGML_SCHED_PRIO_MEDIUM; break;
+            case 3: params.cpuparams_batch.priority = GGML_SCHED_PRIO_HIGH; break;
+            case 4: params.cpuparams_batch.priority = GGML_SCHED_PRIO_REALTIME; break;
            default: params.cpuparams_batch.priority = GGML_SCHED_PRIO_NORMAL; break;
        }
    }
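
The new extract_logit_bias_strings helper pulls a delimited list out of the prompt text and appends each comma-terminated segment to config["logit_bias_strings_exact"]. A hypothetical call, with made-up <bias>...</bias> delimiters chosen purely for illustration:

// Hypothetical usage sketch; the delimiters and input are invented.
nlohmann::json config;
std::string text = "Write a story. <bias>foo,bar,baz,</bias>";

extract_logit_bias_strings(text, "<bias>", "</bias>", config);
// text   -> "Write a story. "
// config -> {"logit_bias_strings_exact": ["foo", "bar", "baz"]}

Note that the split loop only emits a segment when it sees the comma that closes it, so a list without a trailing comma ("foo,bar,baz") silently drops its last entry.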

base_sampling2/llama-addon.cpp

Lines changed: 6 additions & 2 deletions

@@ -46,6 +46,7 @@ float xtc_percent = 0.0;
 int candidates_max = 0;
 int candidates_max_min_p = 0;
 std::string last_candidates = "NONE";
+std::vector<llama_token> last_candidates_logits;
 
 int rx_total = 0;
 int rx_removed = 0;

@@ -154,13 +155,16 @@ static bool writeCandidatesToFile2vec(std::string path, std::vector<llama_token_
 
 static std::string getFormattedCandidates(llama_token_data_array * candidates) {
    std::string text = "(" + std::to_string(candidates->size) + "): ";
+    last_candidates_logits.clear();
    int zeroes = 0;
    for (size_t i = 0; i < candidates->size; ++i) {
        int chance = candidates->data[i].p * 100;
        int logit = candidates->data[i].logit;
-        if (chance > 0 || candidates->size == 1) {
-            text += " #" + std::to_string(i) +"[" + std::to_string(chance) + "%|" + std::to_string(logit) + "]";
+        if (chance > 0 || candidates->size == 1) {
+            text += " #" + std::to_string(i) +"[" + std::to_string(chance) + "%|" + std::to_string(logit) + "]";
        } else ++zeroes;
+
+        last_candidates_logits.push_back(candidates->data[i].id);
    }
    //if (zeroes > 0) text += "~" + std::to_string(zeroes);
 
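Despite the name, last_candidates_logits records candidate token ids (the .id field), not raw logit values, and it records every candidate, including those whose summary entry is suppressed for rounding to 0%. An invented example of the resulting state after one sampling step:

// Purely illustrative values:
//   getFormattedCandidates(...)   -> "(3):  #0[62%|14] #1[37%|13]"  (#2 rounded to 0%, counted in zeroes)
//   last_candidates_logits        -> { 450, 1207, 29889 }           (token ids, one per candidate)
//   get_last_candidates_logits_display() in chat_layer.h then renders each id
//   with common_token_to_piece(ctx, id), producing something like "The; cat; .; "
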
base_sampling2/master/ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

@@ -318,7 +318,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
        execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
    endif()
 
-    string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
+    string(TOUPPER "${POWER10_M}" POWER10_M_UPPER)
+    string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M_UPPER}")
    string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
 
    if (EXTRACTED_NUMBER GREATER_EQUAL 10)
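
The TOUPPER pass is presumably there because prtconf does not always report the implementation in upper case (e.g. "IBM,Power10" rather than "POWER10"); uppercasing the captured output first lets the all-caps regex match either spelling.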

base_sampling2/master/ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 0 additions & 2 deletions

@@ -2430,8 +2430,6 @@ static bool ggml_thread_apply_priority(int32_t prio) {
    // Newer Windows 11 versions aggressively park (offline) CPU cores and often place
    // all our threads onto the first 4 cores which results in terrible performance with
    // n_threads > 4
-    // MinGW doesn't support THREAD_POWER_THROTTLING_CURRENT_VERSION
-    // and THREAD_POWER_THROTTLING_EXECUTION_SPEED
 #if !defined(__GNUC__) && _WIN32_WINNT >= 0x0602
    THREAD_POWER_THROTTLING_STATE t;
    ZeroMemory(&t, sizeof(t));

base_sampling2/master/ggml/src/ggml-cpu/ops.cpp

Lines changed: 2 additions & 2 deletions

@@ -8132,8 +8132,8 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
    #define WKV_VECTOR_SIZE 4
    #endif
 
-    int wkv_vector_size;
    #ifdef WKV_VECTOR_SIZE
+    int wkv_vector_size;
    #if defined(__ARM_FEATURE_SVE)
        wkv_vector_size = svcntw();
    #else

@@ -8348,8 +8348,8 @@ static void ggml_compute_forward_gla_f32(
    #define GLA_VECTOR_SIZE 4
    #endif
 
-    int gla_vector_size;
    #ifdef GLA_VECTOR_SIZE
+    int gla_vector_size;
    #if defined(__ARM_FEATURE_SVE)
        gla_vector_size = svcntw();
    #else
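
Moving the declarations inside the #ifdef blocks scopes wkv_vector_size and gla_vector_size to the only code that reads them, most likely to silence an unused-variable warning in configurations where the corresponding macro never gets defined.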

base_sampling2/master/ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 6 additions & 0 deletions

@@ -95,6 +95,12 @@ set(GGML_OPENCL_KERNELS
    sub
    sum_rows
    transpose
+    concat
+    tsembd
+    upscale
+    tanh
+    pad
+    repeat
 )
 
 foreach (K ${GGML_OPENCL_KERNELS})
