Skip to content

Commit b39aa8e

Browse files
committed
imatrix: be able to specify the name of the output tensor
picked from ik_llama.cpp, a llama_cpp fork maintained by Iwan Kawrakow
1 parent 6320b59 commit b39aa8e

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

common/common.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
12961296
params.process_output = true;
12971297
return true;
12981298
}
1299+
if (arg == "--output-tensor-name") {
1300+
if (++i >= argc) {
1301+
invalid_param = true;
1302+
return true;
1303+
}
1304+
params.output_tensor_name = argv[i];
1305+
return true;
1306+
}
12991307
if (arg == "--no-ppl") {
13001308
params.compute_ppl = false;
13011309
return true;

common/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ struct gpt_params {
246246

247247
// imatrix params
248248
std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
249+
250+
std::string output_tensor_name = "output.weight"; // name of the output tensor
249251

250252
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
251253
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations

examples/imatrix/imatrix.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
8383
if (t->op != GGML_OP_MUL_MAT) return false;
8484
// why are small batches ignored (<16 tokens)?
8585
if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false;
86-
if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == "output.weight"))) return false;
86+
printf("wname = %s\n", wname.c_str());
87+
if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == m_params.output_tensor_name))) return false;
8788
return true;
8889
}
8990

0 commit comments

Comments
 (0)