Skip to content

Commit 554593d

Browse files
committed
Variable scopes are fun
1 parent 0b30188 commit 554593d

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

tools/main/main.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <sstream>
1515
#include <string>
1616
#include <vector>
17+
#include <mutex>
1718

1819
// Forward declarations for internal cache access
1920
struct llama_memory_hybrid;
@@ -92,6 +93,7 @@ static void sigint_handler(int signo) {
9293
// Shared state handed to the graph-eval debug callback (ggml_debug) via
// params.cb_eval_user_data. The mutex member makes this type non-copyable
// and non-movable, so a single instance must outlive the evaluation.
struct callback_data {
    std::vector<uint8_t> data;                // host-side staging buffer for tensor bytes copied off-device
    std::map<std::string, int32_t> tensors;   // per-tensor-name counter: how many times each matching tensor was saved
    std::mutex mutex;                         // serializes callback invocations (ggml_debug takes a lock_guard on entry)
};
9698

9799

@@ -210,6 +212,7 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
210212

211213
static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
212214
auto * cb_data = (callback_data *) user_data;
215+
std::lock_guard<std::mutex> lock(cb_data->mutex);
213216

214217
const struct ggml_tensor * src0 = t->src[0];
215218
const struct ggml_tensor * src1 = t->src[1];
@@ -241,16 +244,18 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
241244

242245
if (!ggml_is_quantized(t->type)) {
243246
uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
244-
ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
245-
if (std::string(t->name).substr(0, std::string("post_moe-").size()) == "post_moe-" ||
246-
std::string(t->name).substr(0, std::string("state_1d-").size()) == "state_1d-") {
247-
if (cb_data->tensors.count(t->name) == 0) {
248-
cb_data->tensors[t->name] = 1;
247+
std::string tensor_name(t->name);
248+
if (std::string(tensor_name).substr(0, std::string("post_moe-").size()) == "post_moe-" ||
249+
std::string(tensor_name).substr(0, std::string("state_1d-").size()) == "state_1d-") {
250+
251+
if (cb_data->tensors.count(tensor_name) == 0) {
252+
cb_data->tensors[tensor_name] = 1;
249253
} else {
250-
cb_data->tensors[t->name]++;
254+
cb_data->tensors[tensor_name]++;
251255
}
252-
save_tensor(t, data, (std::string(t->name) + "_" + std::to_string(cb_data->tensors[t->name]) + ".bin").c_str());
256+
save_tensor(t, data, (tensor_name + "_" + std::to_string(cb_data->tensors[t->name]) + ".bin").c_str());
253257
}
258+
ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
254259
}
255260

256261
return true;
@@ -312,9 +317,9 @@ int main(int argc, char ** argv) {
312317
std::vector<common_chat_msg> chat_msgs;
313318

314319
// load the model and apply lora adapter, if any
320+
callback_data cb_data;
315321
if (params.n_predict > 0 && params.n_predict < 50) {
316322
// enable debug prints if we print small number of tokens
317-
callback_data cb_data;
318323
params.cb_eval = ggml_debug;
319324
params.cb_eval_user_data = &cb_data;
320325
}

0 commit comments

Comments
 (0)